TH1/AITrainPython/debug_onnx.py
2025-12-06 11:44:43 +08:00

90 lines
2.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import onnxruntime as ort
# Debug script: load an exported PPO policy in ONNX Runtime, print its
# input/output signature, then run two hand-made inference scenarios and
# report the softmax probability the model assigns to each candidate action.

# Path of the exported ONNX model to inspect.
MODEL_PATH = 'models/ppo_model_epoch_20.onnx'
# Width of the random vector fed to the model's 'state' input.
STATE_DIM = 572


def softmax(logits):
    """Return softmax probabilities over every element of *logits* as a 1-D array.

    The input is flattened first, so a model output that keeps a leading
    batch axis, e.g. shape ``(1, N)``, yields the same ``N`` probabilities as
    an output of shape ``(N,)``. Without this, iterating the result alongside
    the action ids would walk rows instead of individual probabilities.

    The maximum is subtracted before exponentiation so large logits do not
    overflow. An empty input returns an empty array.
    """
    flat = np.asarray(logits).reshape(-1)
    if flat.size == 0:
        return flat.astype(float)
    shifted = np.exp(flat - np.max(flat))
    return shifted / shifted.sum()


def describe_end_turn_probability(prob):
    """Return the report lines for the probability assigned to actionId=0.

    Above 0.99 is reported as a serious problem, above 0.5 as a warning,
    anything else as acceptable. The first line always carries the
    probability as a percentage with two decimals.
    """
    if prob > 0.99:
        return [
            f"\n❌ 严重问题: actionId=0 的概率是 {prob*100:.2f}% (接近100%)",
            " 模型学习到了错误的模式:总是选择结束回合!",
        ]
    if prob > 0.5:
        return [
            f"\n⚠️ 警告: actionId=0 的概率是 {prob*100:.2f}% (超过50%)",
            " 模型可能过度偏好结束回合。",
        ]
    return [f"\n✅ actionId=0 的概率是 {prob*100:.2f}%,在合理范围内。"]


def print_model_io(session):
    """Print name, shape and type of every input and output of *session*."""
    print("模型输入:")
    for inp in session.get_inputs():
        print(f" {inp.name}: {inp.shape} ({inp.type})")
    print("\n模型输出:")
    for out in session.get_outputs():
        print(f" {out.name}: {out.shape} ({out.type})")


def print_banner(title):
    """Print *title* framed by two 60-character '=' rules."""
    print("\n" + "="*60)
    print(title)
    print("="*60)


def run_inference(session, state, actions, *, show_dtype=False):
    """Run the model once, echo inputs and raw output, and return the first output.

    *state* is fed to the 'state' input and *actions* to the 'valid_actions'
    input. The raw (unflattened) output is printed so its real shape stays
    visible. Set *show_dtype* to also print the output dtype.
    """
    print("\n输入:")
    print(f" state shape: {state.shape}")
    print(f" actions: {actions}")
    raw = session.run(None, {'state': state, 'valid_actions': actions})[0]
    print("\n输出:")
    print(f" output shape: {raw.shape}")
    if show_dtype:
        print(f" output dtype: {raw.dtype}")
    print(f" output: {raw}")
    return raw


def print_probabilities(actions, probs):
    """Print one line per (action, probability) pair, flagging actionId=0."""
    print("\n概率:")
    for i, (action, prob) in enumerate(zip(actions, probs)):
        marker = " ← ⚠️ actionId=0" if action == 0 else ""
        print(f" [{i}] Action {action}: {prob:.6f} ({prob*100:.2f}%){marker}")


def main(model_path=MODEL_PATH):
    """Load the model at *model_path* and run both diagnostic scenarios."""
    # Load the model.
    session = ort.InferenceSession(model_path)
    # Print the model's signature.
    print_model_io(session)

    # Scenario 1: candidate actions do not include actionId=0.
    print_banner("场景1: 不包含 actionId=0")
    state = np.random.randn(1, STATE_DIM).astype(np.float32)
    actions = np.array([10, 25, 37, 42, 99], dtype=np.int64)
    probs = softmax(run_inference(session, state, actions, show_dtype=True))
    print_probabilities(actions, probs)
    # Check whether the first listed action takes essentially all the mass.
    if probs[0] > 0.99:
        print("\n❌ 警告: 第一个动作概率接近100%")
    else:
        print("\n✅ 概率分布正常")

    # Scenario 2: same state, candidates include actionId=0 (the script's
    # messages label it as "end the turn").
    print_banner("场景2: 包含 actionId=0结束回合")
    actions_with_zero = np.array([0, 10, 25, 37, 42], dtype=np.int64)
    probs2 = softmax(run_inference(session, state, actions_with_zero))
    print_probabilities(actions_with_zero, probs2)
    # Report how much probability actionId=0 receives.
    action_0_index = np.where(actions_with_zero == 0)[0][0]
    for line in describe_end_turn_probability(probs2[action_0_index]):
        print(line)


if __name__ == "__main__":
    main()