# Source listing: 90 lines, 2.8 KiB, Python
import numpy as np
import onnxruntime as ort
# Diagnostic script: load an exported policy model (ONNX), run it on a random
# state with two candidate action sets, and report whether the resulting
# distribution collapses onto one action -- in particular onto actionId=0,
# which the scenario-2 messages label as "end turn".

# Path of the exported model to inspect.
MODEL_PATH = 'models/ppo_model_epoch_20.onnx'
# Feature width of the random test state fed to the model's 'state' input.
# NOTE(review): must match the model's declared 'state' shape -- confirm.
STATE_DIM = 572


def softmax(logits):
    """Return the softmax of *logits* as a flat 1-D array.

    The input is flattened first, so a model output of shape ``(1, n)`` and
    one of shape ``(n,)`` both yield ``n`` probabilities that can be paired
    element-wise with ``n`` actions. The maximum is subtracted before
    exponentiating to avoid overflow. An empty input yields an empty array.
    """
    flat = np.ravel(np.asarray(logits))
    if flat.size == 0:
        return flat
    exp_logits = np.exp(flat - np.max(flat))
    return exp_logits / exp_logits.sum()


def print_banner(title):
    """Print *title* between two 60-character '=' rules, after a blank line."""
    rule = "=" * 60
    print("\n" + rule)
    print(title)
    print(rule)


def run_scenario(session, state, actions):
    """Run one inference and print the inputs, raw output and probabilities.

    Args:
        session: an ``onnxruntime.InferenceSession`` whose inputs are named
            'state' and 'valid_actions'.
        state: array fed to the 'state' input.
        actions: int64 array of candidate action ids fed to 'valid_actions'.

    Returns:
        The 1-D probability array obtained by applying ``softmax`` to the
        first model output.
    """
    print("\n输入:")
    print(f" state shape: {state.shape}")
    print(f" actions: {actions}")

    outputs = session.run(None, {'state': state, 'valid_actions': actions})
    logits = outputs[0]

    print("\n输出:")
    print(f" output shape: {logits.shape}")
    print(f" output dtype: {logits.dtype}")
    print(f" output: {logits}")

    # NOTE(review): assumes outputs[0] holds unnormalized logits, one per entry
    # of `actions` and in the same order -- confirm against the export code.
    probs = softmax(logits)
    if probs.size != len(actions):
        # zip() below silently truncates, so surface a count mismatch explicitly.
        print(f"\n⚠️ 警告: 输出 logits 数量 ({probs.size}) 与 actions 数量 ({len(actions)}) 不一致,动作与概率的对应关系可能不正确!")

    print("\n概率:")
    for i, (action, prob) in enumerate(zip(actions, probs)):
        marker = " ← ⚠️ actionId=0" if action == 0 else ""
        print(f" [{i}] Action {action}: {prob:.6f} ({prob*100:.2f}%){marker}")
    return probs


def main(model_path=MODEL_PATH, seed=None):
    """Load the model at *model_path* and run both diagnostic scenarios.

    Args:
        model_path: path of the ONNX model to load.
        seed: optional seed for the random test state. ``None`` (the default)
            draws a fresh random state on every run.
    """
    session = ort.InferenceSession(model_path)

    # Print the model's input/output signature.
    print("模型输入:")
    for inp in session.get_inputs():
        print(f" {inp.name}: {inp.shape} ({inp.type})")
    print("\n模型输出:")
    for out in session.get_outputs():
        print(f" {out.name}: {out.shape} ({out.type})")

    # Scenario 1: actionId=0 is not among the candidates.
    print_banner("场景1: 不包含 actionId=0")
    rng = np.random.default_rng(seed)
    # The same state is reused in scenario 2 so the two runs are comparable.
    state = rng.standard_normal((1, STATE_DIM)).astype(np.float32)
    actions = np.array([10, 25, 37, 42, 99], dtype=np.int64)
    probs = run_scenario(session, state, actions)

    # Nearly all mass on the first candidate suggests the policy has collapsed.
    if probs[0] > 0.99:
        print("\n❌ 警告: 第一个动作概率接近100%!")
    else:
        print("\n✅ 概率分布正常")

    # Scenario 2: actionId=0 is included as the first candidate.
    print_banner("场景2: 包含 actionId=0(结束回合)")
    actions_with_zero = np.array([0, 10, 25, 37, 42], dtype=np.int64)
    probs2 = run_scenario(session, state, actions_with_zero)

    # Probability assigned to actionId=0, graded against two thresholds.
    action_0_index = int(np.flatnonzero(actions_with_zero == 0)[0])
    action_0_prob = probs2[action_0_index]
    if action_0_prob > 0.99:
        print(f"\n❌ 严重问题: actionId=0 的概率是 {action_0_prob*100:.2f}% (接近100%)!")
        print(" 模型学习到了错误的模式:总是选择结束回合!")
    elif action_0_prob > 0.5:
        print(f"\n⚠️ 警告: actionId=0 的概率是 {action_0_prob*100:.2f}% (超过50%)!")
        print(" 模型可能过度偏好结束回合。")
    else:
        print(f"\n✅ actionId=0 的概率是 {action_0_prob*100:.2f}%,在合理范围内。")


if __name__ == "__main__":
    main()