# File listing info: 110 lines, 4.2 KiB, Python
"""
模拟C#端的问题:检查为什么ActionId=0总是100%概率

Reproduce the C#-side issue: check why ActionId=0 always gets 100% probability.
"""

import numpy as np
import onnxruntime as ort


def _softmax(logits):
    """Return a numerically stable softmax over every element of ``logits``.

    The input is flattened first, so a ``(n,)`` vector and a ``(1, n)`` row
    both yield a 1-D probability vector of length ``n``.
    """
    flat = np.asarray(logits, dtype=np.float64).ravel()
    # Shifting by the max leaves the softmax unchanged but prevents overflow.
    exp_logits = np.exp(flat - flat.max())
    return exp_logits / exp_logits.sum()


def _run_scenario(session, state, valid_actions):
    """Run one inference and print the logits and per-action probabilities.

    Args:
        session: an ``onnxruntime.InferenceSession`` fed through inputs
            named ``'state'`` and ``'valid_actions'``.
        state: array passed as the ``state`` input.
        valid_actions: 1-D int64 array of candidate action ids.

    Returns:
        1-D array of softmax probabilities, one per entry of ``valid_actions``.

    Raises:
        ValueError: if the model returns a different number of logits than
            there are valid actions; ``zip`` would otherwise silently
            truncate and mislabel the printout.
    """
    outputs = session.run(None, {'state': state, 'valid_actions': valid_actions})
    logits = outputs[0]
    print(f"\nLogits: {logits}")

    # NOTE(review): the exact output shape is not visible here; flattening
    # covers both (n,) and (1, n) — confirm against the exported model.
    probs = _softmax(logits)
    if probs.size != len(valid_actions):
        raise ValueError(
            f"model returned {probs.size} logits for "
            f"{len(valid_actions)} valid actions"
        )

    print("\nProbs:")
    for i, (action_id, prob) in enumerate(zip(valid_actions, probs)):
        marker = " ← ActionId=0" if action_id == 0 else ""
        print(f" [{i}] ActionId {action_id}: {prob:.6f} ({prob*100:.2f}%){marker}")
    return probs


def _report_reproduction(prob_zero, threshold):
    """Print whether ActionId=0 received more than ``threshold`` of the mass."""
    if prob_zero > threshold:
        print(f"\n❌ 复现问题!ActionId=0的概率是{prob_zero*100:.2f}%")
    else:
        print(f"\n✅ 未复现问题。ActionId=0的概率是{prob_zero*100:.2f}%")


def test_with_actual_data(model_path='models/ppo_model_epoch_20.onnx', *,
                          seed=None, threshold=0.99):
    """Probe the exported ONNX policy with near-realistic game inputs.

    Runs three scenarios and prints the logits and softmax probabilities for
    each, checking whether ActionId=0 absorbs (almost) all probability mass:

    1. an all-zero state;
    2. a state of small random values (std 0.01);
    3. a standard-normal state, evaluated with the original ValidActions
       order (control) and with ActionId=0 moved to the front.

    Args:
        model_path: path of the ONNX model to load.
        seed: if given, scenarios 2 and 3 draw from a private
            ``np.random.RandomState(seed)`` so runs are reproducible;
            ``None`` draws from numpy's global random state.
        threshold: probability above which ActionId=0 is considered to have
            taken over the distribution.
    """
    print("="*80)
    print("测试场景:使用接近实际游戏的数据")
    print("="*80)

    # Load the model.
    session = ort.InferenceSession(model_path)
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Scenario 1: all-zero state (abnormal input).
    print("\n场景1: State全为0(异常情况)")
    print("-"*80)
    state_zeros = np.zeros((1, 572), dtype=np.float32)
    valid_actions = np.array(
        [3624998400, 3692107264, 3759216128, 4832957952,
         1115654, 1115655, 1115656, 1115657, 0],
        dtype=np.int64,
    )
    action_0_index = np.where(valid_actions == 0)[0][0]

    print(f"State: 全0,shape={state_zeros.shape}")
    print(f"ValidActions: {valid_actions}")
    print(f"ActionId=0 at index: {action_0_index}")

    probs = _run_scenario(session, state_zeros, valid_actions)
    _report_reproduction(probs[action_0_index], threshold)

    # Scenario 2: small random state values.
    print("\n" + "="*80)
    print("场景2: State为很小的随机值")
    print("-"*80)
    state_small = rng.randn(1, 572).astype(np.float32) * 0.01
    print(f"State: 随机小值,mean={state_small.mean():.6f}, std={state_small.std():.6f}")

    probs = _run_scenario(session, state_small, valid_actions)
    _report_reproduction(probs[action_0_index], threshold)

    # Scenario 3: does the order of ValidActions matter?
    print("\n" + "="*80)
    print("场景3: 改变ValidActions的顺序")
    print("-"*80)
    # Move ActionId=0 to the front.
    valid_actions_reordered = np.array(
        [0, 3624998400, 3692107264, 3759216128, 4832957952,
         1115654, 1115655, 1115656, 1115657],
        dtype=np.int64,
    )
    state_normal = rng.randn(1, 572).astype(np.float32)

    # Control: the SAME state with the ORIGINAL order. Without it the state
    # and the order would change together, so any difference in ActionId=0's
    # probability could not be attributed to the ordering.
    print(f"ValidActions (原始顺序, 对照组): {valid_actions}")
    probs_control = _run_scenario(session, state_normal, valid_actions)
    p0_control = probs_control[action_0_index]

    action_0_index_new = np.where(valid_actions_reordered == 0)[0][0]
    print(f"\nValidActions (重新排序): {valid_actions_reordered}")
    print(f"ActionId=0 at index: {action_0_index_new}")
    probs_reordered = _run_scenario(session, state_normal, valid_actions_reordered)
    p0_reordered = probs_reordered[action_0_index_new]

    if p0_control <= threshold:
        # The issue is not present for this state even in the original order,
        # so the reordered run says nothing about the order's effect.
        print(f"\n⚠️ 对照组(原始顺序)未复现问题,ActionId=0的概率是{p0_control*100:.2f}%")
        print(f" 重新排序后ActionId=0的概率是{p0_reordered*100:.2f}%,无法据此判断顺序的影响。")
    elif p0_reordered > threshold:
        print(f"\n❌ ActionId=0的概率仍然是{p0_reordered*100:.2f}%")
        print(" 问题与ValidActions的顺序无关!")
    else:
        print(f"\n✅ ActionId=0的概率是{p0_reordered*100:.2f}%")
        print(" 改变顺序后问题消失了!可能与训练时的顺序有关。")


if __name__ == "__main__":
    # Run the diagnostic scenarios when this file is executed as a script.
    test_with_actual_data()