TH1/AITrainPython/test_actionid0_problem.py
2025-12-06 11:44:43 +08:00

110 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Reproduce the C#-side issue: check why ActionId=0 always gets ~100% probability.

Loads an ONNX PPO policy model, feeds it a few synthetic states together with a
fixed list of candidate action ids, and prints the softmax probability of each
candidate so the share assigned to ActionId=0 can be inspected.
"""
import numpy as np
import onnxruntime as ort
def _softmax(logits):
    """Return the softmax of ``logits`` as a flat float64 probability vector.

    The input is flattened first so that logits shaped ``(N,)`` or ``(1, N)``
    both yield one probability per candidate action.
    NOTE(review): the exact output shape of the ONNX model is not visible
    here -- confirm it is one logit per entry of ``valid_actions``.
    """
    flat = np.asarray(logits, dtype=np.float64).reshape(-1)
    # Shift by the max before exponentiating so large logits cannot overflow.
    exp_logits = np.exp(flat - np.max(flat))
    return exp_logits / exp_logits.sum()


def _run_scenario(session, state, valid_actions):
    """Run the model once and print its logits and per-action probabilities.

    Args:
        session: an onnxruntime ``InferenceSession`` whose inputs are named
            ``state`` and ``valid_actions``.
        state: float32 state array fed to the ``state`` input.
        valid_actions: 1-D int64 array of candidate action ids; must contain 0.

    Returns:
        The softmax probability assigned to ActionId=0, as a Python float.

    Raises:
        ValueError: if the model returns a different number of logits than
            there are candidate actions, so probabilities cannot be matched
            to action ids.
    """
    outputs = session.run(None, {'state': state, 'valid_actions': valid_actions})
    logits = outputs[0]
    print(f"\nLogits: {logits}")
    probs = _softmax(logits)
    if probs.size != valid_actions.size:
        # zip() below would silently truncate and mislabel the probabilities.
        raise ValueError(
            f"model returned {probs.size} logits for {valid_actions.size} valid actions"
        )
    print("\nProbs:")
    for i, (action_id, prob) in enumerate(zip(valid_actions, probs)):
        marker = " ← ActionId=0" if action_id == 0 else ""
        print(f" [{i}] ActionId {action_id}: {prob:.6f} ({prob*100:.2f}%){marker}")
    action_0_index = np.where(valid_actions == 0)[0][0]
    return float(probs[action_0_index])


def _report_reproduction(p_action_0, threshold):
    """Print whether ActionId=0's probability exceeds ``threshold`` (problem reproduced)."""
    if p_action_0 > threshold:
        print(f"\n❌ 复现问题ActionId=0的概率是{p_action_0*100:.2f}%")
    else:
        print(f"\n✅ 未复现问题。ActionId=0的概率是{p_action_0*100:.2f}%")


def test_with_actual_data(*, model_path='models/ppo_model_epoch_20.onnx',
                          state_dim=572, seed=0, threshold=0.99):
    """Probe whether the ONNX policy puts (almost) all probability on ActionId=0.

    Runs three scenarios against the same list of candidate action ids:
      1. an all-zero state;
      2. a small random state (std 0.01), drawn from a seeded generator so
         runs are reproducible;
      3. the scenario-1 state with ActionId=0 moved to the front of the list,
         so the action order is the only thing that differs from scenario 1.

    All parameters are keyword-only with defaults, so calling with no
    arguments behaves like the original script.

    Args:
        model_path: path to the ONNX model file.
        state_dim: width of the ``state`` input.
        seed: seed for the scenario-2 random state.
        threshold: probability above which ActionId=0 counts as "always chosen".
    """
    print("="*80)
    print("测试场景:使用接近实际游戏的数据")
    print("="*80)
    # Load the model. Providers are explicit because GPU-enabled onnxruntime
    # builds (>= 1.9) refuse to create a session without them.
    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    valid_actions = np.array(
        [3624998400, 3692107264, 3759216128, 4832957952,
         1115654, 1115655, 1115656, 1115657, 0],
        dtype=np.int64,
    )

    # Scenario 1: all-zero state (abnormal input).
    print("\n场景1: State全为0异常情况")
    print("-"*80)
    state_zeros = np.zeros((1, state_dim), dtype=np.float32)
    print(f"State: 全0shape={state_zeros.shape}")
    print(f"ValidActions: {valid_actions}")
    print(f"ActionId=0 at index: {np.where(valid_actions == 0)[0][0]}")
    p0_zeros = _run_scenario(session, state_zeros, valid_actions)
    _report_reproduction(p0_zeros, threshold)

    # Scenario 2: small random state.
    print("\n" + "="*80)
    print("场景2: State为很小的随机值")
    print("-"*80)
    rng = np.random.default_rng(seed)
    state_small = (rng.standard_normal((1, state_dim)) * 0.01).astype(np.float32)
    print(f"State: 随机小值mean={state_small.mean():.6f}, std={state_small.std():.6f}")
    p0_small = _run_scenario(session, state_small, valid_actions)
    _report_reproduction(p0_small, threshold)

    # Scenario 3: move ActionId=0 to the front. Reuse the scenario-1 state so
    # that ordering is the only variable; a fresh random state would confound
    # "order" with "state".
    print("\n" + "="*80)
    print("场景3: 改变ValidActions的顺序")
    print("-"*80)
    valid_actions_reordered = np.array(
        [0, 3624998400, 3692107264, 3759216128, 4832957952,
         1115654, 1115655, 1115656, 1115657],
        dtype=np.int64,
    )
    print(f"ValidActions (重新排序): {valid_actions_reordered}")
    print(f"ActionId=0 at index: {np.where(valid_actions_reordered == 0)[0][0]}")
    p0_reordered = _run_scenario(session, state_zeros, valid_actions_reordered)
    if p0_zeros <= threshold:
        # The baseline did not reproduce the problem, so neither "still" nor
        # "went away" is a valid conclusion about ordering.
        print(f"\nActionId=0的概率是{p0_reordered*100:.2f}%")
        print(" 场景1未复现问题, 无法据此判断ValidActions顺序的影响")
    elif p0_reordered > threshold:
        print(f"\n❌ ActionId=0的概率仍然是{p0_reordered*100:.2f}%")
        print(" 问题与ValidActions的顺序无关")
    else:
        print(f"\n✅ ActionId=0的概率是{p0_reordered*100:.2f}%")
        print(" 改变顺序后问题消失了!可能与训练时的顺序有关。")
if __name__ == "__main__":
    # Run the diagnostic when this file is executed directly as a script.
    test_with_actual_data()