# TH1/AITrainPython/test_model_consistency.py
# 2025-12-06 11:44:43 +08:00
#
# 575 lines
# 22 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
模型一致性测试脚本
测试 PTH 和 ONNX 模型的输出是否一致,以及概率分布是否合理
"""
import os
import json
import numpy as np
import torch
import onnxruntime as ort
from ppo_model import PPONetwork
from data_loader import TrainingDataLoader
from colorama import init, Fore, Style
import sys
# Initialize colorama for colored console output (autoreset=True is passed so
# colors do not carry over between prints, per colorama's documented option).
init(autoreset=True)
class ModelConsistencyTester:
    """Compare a PyTorch (.pth) checkpoint against its ONNX export.

    For each sample, both models score the same candidate actions. The tester
    checks that the two softmax distributions agree and that each distribution
    passes basic sanity checks. Optionally, it also checks that the argmax
    action matches the action recorded in the training data.
    """

    def __init__(self, pth_path='models/ppo_model_best.pth',
                 onnx_path='models/ppo_model_best.onnx',
                 state_dim=572):
        """
        Args:
            pth_path: path to the PyTorch checkpoint.
            onnx_path: path to the exported ONNX model.
            state_dim: length of the state vector fed to the network.
        """
        self.pth_path = pth_path
        self.onnx_path = onnx_path
        self.state_dim = state_dim
        # Both runtimes are evaluated on CPU for the comparison.
        self.device = torch.device('cpu')

        print(f"{Fore.CYAN}{'='*80}")
        print(f"{Fore.CYAN}模型一致性测试")
        print(f"{Fore.CYAN}{'='*80}\n")

        self._load_pth_model()
        self._load_onnx_model()

    def _load_pth_model(self):
        """Load the PyTorch checkpoint into a freshly built PPONetwork.

        The hidden size is read from the checkpoint's ``shared_net.0.weight``.

        Raises:
            FileNotFoundError: if ``self.pth_path`` does not exist.
            RuntimeError: if the weights do not fit the network.
        """
        print(f"{Fore.YELLOW}[1/2] 加载 PyTorch 模型...")
        if not os.path.exists(self.pth_path):
            raise FileNotFoundError(f"PTH模型不存在: {self.pth_path}")

        # NOTE: weights_only=False unpickles arbitrary objects -- only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(self.pth_path, map_location=self.device, weights_only=False)
        # Training checkpoints nest the weights under 'policy_state_dict';
        # a bare state dict is accepted as well.
        if 'policy_state_dict' in checkpoint:
            state_dict = checkpoint['policy_state_dict']
        else:
            state_dict = checkpoint

        # Assumes shared_net[0] is an nn.Linear whose out_features (dim 0 of
        # its weight) equals hidden_dim.
        hidden_dim = state_dict['shared_net.0.weight'].shape[0]

        self.pth_model = PPONetwork(state_dim=self.state_dim, hidden_dim=hidden_dim)
        self._load_state_dict_compat(state_dict)
        self.pth_model.eval()

        print(f"{Fore.GREEN} [OK] PyTorch 模型加载成功")
        print(f" 路径: {self.pth_path}")
        print(f" Hidden Dim: {hidden_dim}")
        if 'total_steps' in checkpoint:
            print(f" 训练步数: {checkpoint['total_steps']}")
        if 'total_episodes' in checkpoint:
            print(f" 训练轮数: {checkpoint['total_episodes']}")
        print()

    def _load_state_dict_compat(self, state_dict):
        """Load ``state_dict`` strictly, tolerating only ``_freqs`` key mismatches.

        Older checkpoints may lack (or carry extra) ``_freqs`` buffers. In that
        case the weights are reloaded non-strictly. Any other missing or
        unexpected key is still treated as an error, so a mismatched checkpoint
        cannot silently pass the consistency test.

        Raises:
            RuntimeError: on any mismatch other than ``_freqs`` keys.
        """
        try:
            self.pth_model.load_state_dict(state_dict)
        except RuntimeError as e:
            if '_freqs' not in str(e):
                raise
            incompatible = self.pth_model.load_state_dict(state_dict, strict=False)
            other_keys = [
                key
                for key in (*incompatible.missing_keys, *incompatible.unexpected_keys)
                if '_freqs' not in key
            ]
            if other_keys:
                raise RuntimeError(
                    f"PTH模型权重与网络结构不匹配(除 _freqs 外): {other_keys}"
                ) from e

    def _load_onnx_model(self):
        """Create a CPU ONNX Runtime session and print its input/output signature.

        Raises:
            FileNotFoundError: if ``self.onnx_path`` does not exist.
        """
        print(f"{Fore.YELLOW}[2/2] 加载 ONNX 模型...")
        if not os.path.exists(self.onnx_path):
            raise FileNotFoundError(f"ONNX模型不存在: {self.onnx_path}")

        # Explicit CPU provider: matches the CPU-only comparison, and ORT builds
        # with extra execution providers require `providers` to be set.
        self.onnx_session = ort.InferenceSession(
            self.onnx_path, providers=['CPUExecutionProvider']
        )

        print(f"{Fore.GREEN} [OK] ONNX 模型加载成功")
        print(f" 路径: {self.onnx_path}")
        print(f" 输入:")
        for input_meta in self.onnx_session.get_inputs():
            print(f" - {input_meta.name}: {input_meta.shape} ({input_meta.type})")
        print(f" 输出:")
        for output_meta in self.onnx_session.get_outputs():
            print(f" - {output_meta.name}: {output_meta.shape} ({output_meta.type})")
        print()

    @staticmethod
    def _softmax(logits):
        """Return a numerically stable softmax of ``logits`` flattened to 1-D.

        Inference here is always for a single state, so a leading batch axis of
        size 1 (e.g. logits shaped (1, n)) is removed. This makes the PTH and
        ONNX paths return arrays of the same shape.
        """
        flat = np.asarray(logits).reshape(-1)
        exp_logits = np.exp(flat - np.max(flat))
        return exp_logits / np.sum(exp_logits)

    def predict_pth(self, state, valid_actions):
        """Score ``valid_actions`` for ``state`` with the PyTorch model.

        Args:
            state: (state_dim,) float array.
            valid_actions: (n_actions,) integer array of candidate action ids
                (uint64 as built by ``test_from_jsonl``).

        Returns:
            (action_probs, selected_action): the (n_actions,) softmax
            distribution, and the candidate with the highest probability.
        """
        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)
            # Feed int64 ids, exactly as predict_onnx does. torch.from_numpy
            # also rejects uint64 arrays on older torch versions.
            action_ids = np.asarray(valid_actions).astype(np.int64)
            action_ids_tensor = torch.from_numpy(action_ids).to(self.device)

            # Same computation path as the ONNX export: shared features -> action logits.
            shared_features, _ = self.pth_model(state_tensor)
            action_logits = self.pth_model.get_action_logits(shared_features, action_ids_tensor)

        action_probs = self._softmax(action_logits.cpu().numpy())
        # Deterministic choice: highest-probability candidate.
        selected_action = valid_actions[int(np.argmax(action_probs))]
        return action_probs, selected_action

    def predict_onnx(self, state, valid_actions):
        """Score ``valid_actions`` for ``state`` with the ONNX model.

        Args:
            state: (state_dim,) float array.
            valid_actions: (n_actions,) integer array of candidate action ids.

        Returns:
            (action_probs, selected_action): the (n_actions,) softmax
            distribution, and the candidate with the highest probability.
        """
        state_input = np.asarray(state).reshape(1, -1).astype(np.float32)
        actions_input = np.asarray(valid_actions).astype(np.int64)

        # The first model output holds the action logits.
        outputs = self.onnx_session.run(
            None,
            {
                'state': state_input,
                'valid_actions': actions_input,
            },
        )

        action_probs = self._softmax(outputs[0])
        selected_action = valid_actions[int(np.argmax(action_probs))]
        return action_probs, selected_action

    def check_probability_distribution(self, probs, threshold=0.05):
        """Sanity-check one probability distribution.

        Args:
            probs: probabilities over the candidate actions. Leading singleton
                axes are flattened away.
            threshold: the distribution is reported as too flat when no entry
                deviates from the uniform value 1/n by at least this much.

        Returns:
            (is_valid, issues): ``issues`` is a list of problem descriptions,
            and ``is_valid`` is True exactly when it is empty.
        """
        probs = np.asarray(probs).reshape(-1)
        n = len(probs)
        if n == 0:
            return False, ["概率分布为空"]

        issues = []

        # 1. Probabilities must sum to ~1.
        prob_sum = np.sum(probs)
        if not np.isclose(prob_sum, 1.0, atol=1e-5):
            issues.append(f"概率和不为1: {prob_sum:.6f}")

        # 2. No NaN / Inf entries.
        if np.any(np.isnan(probs)):
            issues.append("存在NaN值")
        if np.any(np.isinf(probs)):
            issues.append("存在Inf值")

        # 3. Over-confidence: all but one entry ~0. Skipped when there is a
        #    single candidate, because p == [1.0] is then the only possible output.
        if n > 1 and np.sum(probs < 1e-6) == n - 1:
            issues.append("几乎所有概率都为0过度自信")

        # 4. Too flat: maximum deviation from uniform is below the threshold.
        if n > 1:
            max_deviation = np.max(np.abs(probs - 1.0 / n))
            if max_deviation < threshold:
                issues.append(f"概率分布过于平均(最大偏差: {max_deviation:.4f})")

        # 5. Top probability is too low when there are many candidates.
        max_prob = np.max(probs)
        if max_prob < 0.1 and n > 5:
            issues.append(f"最大概率过低: {max_prob:.4f}")

        return len(issues) == 0, issues

    def test_single_sample(self, state, valid_actions, expected_action,
                           prob_diff_tolerance=0.05):
        """Run both models on one sample and collect the checks.

        Args:
            state: (state_dim,) float array.
            valid_actions: (n_actions,) integer array of candidate action ids.
            expected_action: the action recorded in the training data.
            prob_diff_tolerance: largest allowed element-wise |p_pth - p_onnx|
                before the sample is flagged (default 5%, to absorb
                floating-point and implementation differences).

        Returns:
            dict with correctness, consistency and validity flags, both
            distributions and actions, ``max_prob_diff`` (None if prediction
            failed) and a list of ``issues``. Exceptions raised by either model
            are recorded as issues rather than propagated.
        """
        result = {
            'pth_correct': False,
            'onnx_correct': False,
            'models_match': False,
            'pth_prob_valid': False,
            'onnx_prob_valid': False,
            'pth_probs': None,
            'onnx_probs': None,
            'pth_action': None,
            'onnx_action': None,
            'expected_action': expected_action,
            'max_prob_diff': None,
            'issues': [],
        }

        try:
            # PyTorch model.
            pth_probs, pth_action = self.predict_pth(state, valid_actions)
            result['pth_probs'] = pth_probs
            result['pth_action'] = pth_action
            # Compare as Python ints to avoid numpy uint64/int promotion to float.
            result['pth_correct'] = int(pth_action) == int(expected_action)
            pth_valid, pth_issues = self.check_probability_distribution(pth_probs)
            result['pth_prob_valid'] = pth_valid
            result['issues'].extend(f"PTH: {issue}" for issue in pth_issues)

            # ONNX model.
            onnx_probs, onnx_action = self.predict_onnx(state, valid_actions)
            result['onnx_probs'] = onnx_probs
            result['onnx_action'] = onnx_action
            result['onnx_correct'] = int(onnx_action) == int(expected_action)
            onnx_valid, onnx_issues = self.check_probability_distribution(onnx_probs)
            result['onnx_prob_valid'] = onnx_valid
            result['issues'].extend(f"ONNX: {issue}" for issue in onnx_issues)

            # Cross-model consistency: same chosen action and close distributions.
            result['models_match'] = int(pth_action) == int(onnx_action)
            max_prob_diff = float(np.max(np.abs(pth_probs - onnx_probs)))
            result['max_prob_diff'] = max_prob_diff
            if max_prob_diff > prob_diff_tolerance:
                result['issues'].append(f"PTH和ONNX概率差异过大: {max_prob_diff:.6f}")
        except Exception as e:
            result['issues'].append(f"预测异常: {str(e)}")

        return result

    def test_from_jsonl(self, jsonl_path, max_samples=None):
        """Run ``test_single_sample`` on each step of one JSONL episode file.

        Each step must provide 'State', 'Actions' and 'SelectedAction'.

        Args:
            jsonl_path: path to the episode file (read with TrainingDataLoader).
            max_samples: keep only the first ``max_samples`` steps; None means all.

        Returns:
            list of per-sample result dicts (empty if nothing was loaded).

        Raises:
            FileNotFoundError: if ``jsonl_path`` does not exist.
        """
        print(f"{Fore.CYAN}{'='*80}")
        print(f"{Fore.CYAN}从 JSONL 文件测试: {os.path.basename(jsonl_path)}")
        print(f"{Fore.CYAN}{'='*80}\n")

        if not os.path.exists(jsonl_path):
            raise FileNotFoundError(f"JSONL文件不存在: {jsonl_path}")

        episode_data = TrainingDataLoader().load_episode_file(jsonl_path)
        if not episode_data:
            print(f"{Fore.RED}错误: 无法加载数据")
            return []

        if max_samples is not None:
            episode_data = episode_data[:max_samples]

        total = len(episode_data)
        print(f"加载了 {total} 个样本\n")

        results = []
        for i, step_data in enumerate(episode_data, start=1):
            state = np.array(step_data['State'], dtype=np.float32)
            valid_actions = np.array(step_data['Actions'], dtype=np.uint64)
            expected_action = int(step_data['SelectedAction'])

            results.append(self.test_single_sample(state, valid_actions, expected_action))

            if i % 10 == 0:
                print(f" 已测试 {i}/{total} 个样本...")

        print(f"\n{Fore.GREEN}测试完成!\n")
        return results

    def print_summary(self, results, skip_action_check=False):
        """Print an aggregated report for ``results`` and a final verdict.

        Args:
            results: list of dicts as returned by ``test_single_sample``.
            skip_action_check: when True, ignore agreement with the recorded
                training action. Only PTH/ONNX consistency and distribution
                sanity are judged.

        Returns:
            bool: True if the printed verdict is a pass (False for no results).
        """
        if not results:
            print(f"{Fore.RED}无测试结果")
            return False

        total = len(results)
        pth_correct = sum(1 for r in results if r['pth_correct'])
        onnx_correct = sum(1 for r in results if r['onnx_correct'])
        models_match = sum(1 for r in results if r['models_match'])
        pth_prob_valid = sum(1 for r in results if r['pth_prob_valid'])
        onnx_prob_valid = sum(1 for r in results if r['onnx_prob_valid'])

        def sample_passes(r):
            # Both modes require consistent, valid outputs and no recorded issues.
            # Issues also capture large PTH/ONNX probability gaps and exceptions.
            ok = (r['models_match'] and r['pth_prob_valid']
                  and r['onnx_prob_valid'] and not r['issues'])
            if not skip_action_check:
                ok = ok and r['pth_correct'] and r['onnx_correct']
            return ok

        all_pass = sum(1 for r in results if sample_passes(r))

        print(f"{Fore.CYAN}{'='*80}")
        print(f"{Fore.CYAN}测试结果摘要")
        print(f"{Fore.CYAN}{'='*80}\n")
        print(f"总样本数: {total}\n")

        # Action accuracy versus the training data (unless skipped).
        if not skip_action_check:
            print(f"{Fore.YELLOW}【动作选择(与训练数据对比)】")
            self._print_stat("PTH模型正确率", pth_correct, total)
            self._print_stat("ONNX模型正确率", onnx_correct, total)
            print(f"{Fore.CYAN} 注意: 模型已通过PPO学习可能选择与行为树不同的动作{Style.RESET_ALL}")
            print()

        # PTH/ONNX consistency. Samples whose prediction failed have no diff
        # and are excluded from the average rather than counted as 0.
        print(f"{Fore.YELLOW}【模型一致性】")
        self._print_stat("PTH和ONNX输出一致", models_match, total)
        diffs = [r['max_prob_diff'] for r in results if r.get('max_prob_diff') is not None]
        avg_prob_diff = float(np.mean(diffs)) if diffs else float('nan')
        print(f" 平均概率差异: {avg_prob_diff:.6f}")
        print()

        # Distribution sanity.
        print(f"{Fore.YELLOW}【概率分布合理性】")
        self._print_stat("PTH概率分布合理", pth_prob_valid, total)
        self._print_stat("ONNX概率分布合理", onnx_prob_valid, total)
        print()

        # Overall.
        print(f"{Fore.YELLOW}【综合评估】")
        self._print_stat("所有检测项通过", all_pass, total, critical=True)
        print()

        # List problem samples (first 10 only).
        problem_samples = [i for i, r in enumerate(results) if r['issues']]
        if problem_samples:
            print(f"{Fore.RED}【问题样本】")
            print(f" 发现 {len(problem_samples)} 个问题样本\n")
            show_count = min(10, len(problem_samples))
            for idx in problem_samples[:show_count]:
                r = results[idx]
                print(f" 样本 #{idx + 1}:")
                print(f" 期望动作: {r['expected_action']}")
                print(f" PTH预测: {r['pth_action']} {'[OK]' if r['pth_correct'] else '[X]'}")
                print(f" ONNX预测: {r['onnx_action']} {'[OK]' if r['onnx_correct'] else '[X]'}")
                for issue in r['issues']:
                    print(f"{issue}")
                print()
            if len(problem_samples) > show_count:
                print(f" ... 还有 {len(problem_samples) - show_count} 个问题样本未显示\n")

        print(f"{Fore.CYAN}{'='*80}")
        consistency_rate = models_match / total
        if skip_action_check:
            # Pass: consistency >= 95% and distribution validity >= 95% for both models.
            prob_valid_rate = min(pth_prob_valid, onnx_prob_valid) / total
            passed = consistency_rate >= 0.95 and prob_valid_rate >= 0.95
            if passed:
                print(f"{Fore.GREEN}{'[OK] 测试通过!':^80}")
                print(f"{Fore.GREEN}{'PTH和ONNX模型输出一致概率分布合理':^80}")
                print(f"{Fore.GREEN}{'模型可以投入使用':^80}")
            else:
                print(f"{Fore.RED}{'[FAIL] 测试失败!':^80}")
                if consistency_rate < 0.95:
                    print(f"{Fore.RED}{'PTH和ONNX一致性不足95%':^80}")
                if prob_valid_rate < 0.95:
                    print(f"{Fore.RED}{'概率分布合理性不足95%':^80}")
                print(f"{Fore.RED}{'建议检查模型训练或导出过程':^80}")
        else:
            # Pass: >= 80% of samples pass every check and consistency >= 95%.
            pass_rate = all_pass / total
            passed = pass_rate >= 0.8 and consistency_rate >= 0.95
            if passed:
                print(f"{Fore.GREEN}{'[OK] 测试通过!':^80}")
                print(f"{Fore.GREEN}{'所有检测项达标,模型可以投入使用':^80}")
            else:
                print(f"{Fore.RED}{'[FAIL] 测试失败!':^80}")
                if pass_rate < 0.8:
                    print(f"{Fore.RED}{'综合通过率不足80%':^80}")
                if consistency_rate < 0.95:
                    print(f"{Fore.RED}{'PTH和ONNX一致性不足95%':^80}")
                print(f"{Fore.RED}{'建议检查模型训练或导出过程':^80}")
        print(f"{Fore.CYAN}{'='*80}\n")
        return passed

    def _print_stat(self, label, count, total, critical=False):
        """Print one line of the form 'label  count/total (pct%) bar', colored by pct.

        Green at >= 95%, yellow at >= 80%, red otherwise. A ``critical`` row
        is already green at >= 80%.
        """
        percentage = (count / total) * 100 if total > 0 else 0.0

        if percentage >= 95 or (critical and percentage >= 80):
            color = Fore.GREEN
        elif percentage >= 80:
            color = Fore.YELLOW
        else:
            color = Fore.RED

        bar_length = 40
        filled = int(bar_length * count / total) if total > 0 else 0
        # ASCII bar so it renders regardless of the console encoding.
        bar = '#' * filled + '-' * (bar_length - filled)
        print(f" {label:30} {color}{count:4d}/{total:4d} ({percentage:5.1f}%) {bar}{Style.RESET_ALL}")
def main():
    """Command-line entry point: parse arguments, run the test, print the report.

    Returns:
        int: 0 when the test ran to completion (the pass/fail verdict is
        printed), 1 on a missing input or an unexpected exception.
    """
    import argparse

    parser = argparse.ArgumentParser(description='测试PTH和ONNX模型的一致性')
    parser.add_argument('--pth', type=str, default='models/ppo_model_best.pth',
                        help='PyTorch模型路径')
    parser.add_argument('--onnx', type=str, default='models/ppo_model_best.onnx',
                        help='ONNX模型路径')
    parser.add_argument('--data', type=str, default=None,
                        help='JSONL数据文件路径默认随机选择一个')
    parser.add_argument('--max-samples', type=int, default=100,
                        help='最多测试样本数默认100')
    parser.add_argument('--state-dim', type=int, default=572,
                        help='状态维度默认572')
    parser.add_argument('--skip-action-check', action='store_true',
                        help='跳过动作正确性检查(只检查模型一致性)')
    args = parser.parse_args()

    # Without --data, pick an episode file from ./Data.
    if args.data is None:
        data_dir = 'Data'
        if not os.path.isdir(data_dir):
            print(f"{Fore.RED}错误: Data目录不存在")
            return 1
        # Sort first: os.listdir order is arbitrary, so "the middle file" is only
        # meaningful (and reproducible) on a sorted list.
        jsonl_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.jsonl'))
        if not jsonl_files:
            print(f"{Fore.RED}错误: Data目录中没有JSONL文件")
            return 1
        # Take the middle file to avoid the extremes at either end.
        args.data = os.path.join(data_dir, jsonl_files[len(jsonl_files) // 2])
        print(f"{Fore.YELLOW}自动选择数据文件: {os.path.basename(args.data)}\n")

    if not os.path.exists(args.pth):
        print(f"{Fore.RED}错误: PTH模型不存在: {args.pth}")
        return 1
    if not os.path.exists(args.onnx):
        print(f"{Fore.RED}错误: ONNX模型不存在: {args.onnx}")
        print(f"{Fore.YELLOW}提示: 请先运行 export_to_onnx.py 导出ONNX模型")
        return 1

    try:
        tester = ModelConsistencyTester(
            pth_path=args.pth,
            onnx_path=args.onnx,
            state_dim=args.state_dim,
        )
        results = tester.test_from_jsonl(args.data, max_samples=args.max_samples)
        tester.print_summary(results, skip_action_check=args.skip_action_check)
    except Exception as e:
        # Top-level boundary: report and show the traceback instead of crashing.
        print(f"{Fore.RED}错误: {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status (None -> 0).
    sys.exit(main())