"""
Compare the PTH (PyTorch) and ONNX model outputs stage by stage to locate
the source of any difference between them.
"""

import json
from itertools import islice

import numpy as np
import onnxruntime as ort
import torch

from ppo_model import PPONetwork

# Load both models: the PyTorch checkpoint and its ONNX export.
pth_path = 'models/ppo_model_best.pth'
onnx_path = 'models/ppo_model_best.onnx'
device = torch.device('cpu')

# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the file. Only run this script
# against checkpoints that come from a trusted source.
checkpoint = torch.load(pth_path, map_location=device, weights_only=False)
state_dict = checkpoint['policy_state_dict']

# The hidden size is read from the first dimension of the first shared-layer
# weight so that it always matches the checkpoint being loaded.
hidden_dim = state_dict['shared_net.0.weight'].shape[0]
# NOTE(review): state_dim=572 is hard-coded; presumably it equals the length
# of the recorded state vectors -- confirm against the training configuration.
pth_model = PPONetwork(state_dim=572, hidden_dim=hidden_dim)
pth_model.load_state_dict(state_dict)
pth_model.eval()

# Pin ONNX Runtime to the CPU provider: the PyTorch model is loaded on the CPU,
# so both sides run on the same device, and session creation does not depend on
# which execution providers the installed onnxruntime build happens to ship.
onnx_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])

# Use the 19th sample (zero-based index 18) of a recorded episode as the test case.
sample_path = 'Data/episode_20251205_130825_e7af75e5_325.jsonl'
sample_index = 18

# Stream the file and stop at the wanted line instead of loading every line.
# 'utf-8-sig' strips a leading BOM if the file has one.
with open(sample_path, 'r', encoding='utf-8-sig') as f:
    sample_line = next(islice(f, sample_index, None), None)

if sample_line is None:
    raise IndexError(
        f"{sample_path} has fewer than {sample_index + 1} lines; "
        f"cannot read sample index {sample_index}"
    )

data = json.loads(sample_line)
state = np.array(data['State'], dtype=np.float32)
valid_actions = np.array(data['Actions'], dtype=np.int64)

print("测试用例:")
print(f" state shape: {state.shape}")
print(f" valid_actions: {valid_actions}")
print()

# PyTorch (PTH) inference, reproduced stage by stage so that every
# intermediate tensor can be printed and inspected.
print("="*80)
print("PTH 推理")
print("="*80)

with torch.no_grad():
    # Add a leading batch dimension of 1 to the state; the actions stay 1-D.
    state_tensor = torch.from_numpy(state).unsqueeze(0)
    actions_tensor = torch.from_numpy(valid_actions)

    # Step 1: state -> shared features.
    # Call the module instance rather than .forward() so that any registered
    # hooks run; the second element of the returned pair is not used here.
    shared_features, _ = pth_model(state_tensor)
    print(f"1. Shared features shape: {shared_features.shape}")
    print(f" 前10个值: {shared_features[0, :10]}")
    print(f" 统计: mean={shared_features.mean().item():.6f}, std={shared_features.std().item():.6f}")

    # Step 2: actions -> action features.
    action_features = pth_model.action_encoder.encode(actions_tensor)
    print(f"\n2. Action features shape: {action_features.shape}")
    print(f" 第1个action的前10个值: {action_features[0, :10]}")
    print(f" 统计: mean={action_features.mean().item():.6f}, std={action_features.std().item():.6f}")

    # Step 3: pair the single state vector with every action by expanding both
    # to (1, num_actions, features) and concatenating along the feature axis.
    num_actions = len(valid_actions)
    expanded_state = shared_features.unsqueeze(1).expand(1, num_actions, -1)
    expanded_actions = action_features.unsqueeze(0).expand(1, num_actions, -1)
    combined = torch.cat([expanded_state, expanded_actions], dim=-1)
    print(f"\n3. Combined features shape: {combined.shape}")
    print(f" 第1个组合的前10个值: {combined[0, 0, :10]}")

    # Step 4: actor head, applied to one row per (state, action) pair.
    flat_combined = combined.reshape(-1, combined.shape[-1])
    logits = pth_model.actor_head(flat_combined)
    logits = logits.reshape(1, num_actions)
    print(f"\n4. Logits shape: {logits.shape}")
    print(f" 值: {logits[0]}")

    # The reshape above always produces a batch dimension of exactly 1, so it
    # can be dropped unconditionally.
    logits = logits.squeeze(0)

pth_logits = logits.numpy()
print(f"\n最终 PTH logits: {pth_logits}")

# ONNX inference: run the exported graph once on the same test case.
onnx_banner = "=" * 80
print("\n" + onnx_banner)
print("ONNX 推理")
print(onnx_banner)

# Feed dictionary keyed by the graph input names this script uses; the state
# is reshaped to a single row and both arrays are cast to the dtypes passed in.
onnx_feeds = {
    'state': state.reshape(1, -1).astype(np.float32),
    'valid_actions': valid_actions.astype(np.int64),
}

# Passing None as the output list requests every graph output; the logits are
# taken from the first one.
outputs = onnx_session.run(None, onnx_feeds)
onnx_logits = outputs[0]

print(f"最终 ONNX logits: {onnx_logits}")

# Comparison of the two sets of logits.
print("\n" + "="*80)
print("比较")
print("="*80)

# Largest absolute per-logit difference that still counts as a match.
max_allowed_diff = 0.001

# Flatten both results before subtracting. The PyTorch logits are 1-D while the
# ONNX output may carry extra singleton dimensions; relying on broadcasting
# could silently produce a differently shaped (and misleading) difference.
pth_flat = np.asarray(pth_logits).reshape(-1)
onnx_flat = np.asarray(onnx_logits).reshape(-1)

if pth_flat.size != onnx_flat.size:
    raise ValueError(
        f"Logit count mismatch: PTH produced {pth_flat.size} values, "
        f"ONNX produced {onnx_flat.size}"
    )
if pth_flat.size == 0:
    raise ValueError("No logits to compare: both models returned empty outputs")

diff = np.abs(pth_flat - onnx_flat)
max_diff = diff.max()

print(f"Logits 差异: {diff}")
print(f"最大差异: {max_diff:.6f}")
print(f"平均差异: {diff.mean():.6f}")

if max_diff < max_allowed_diff:
    print("\n✅ 一致!差异在可接受范围内")
else:
    print(f"\n❌ 不一致!最大差异 {max_diff:.6f} 超过阈值")