TH1/AITrainPython/compare_layers.py
2025-12-06 11:44:43 +08:00

109 lines
3.4 KiB
Python

"""
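Compare the outputs of the PyTorch (.pth) checkpoint and the ONNX model on one
recorded sample. The PyTorch intermediate results are printed step by step and
the final logits of both models are diffed to locate the source of any mismatch.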
"""
import json
import numpy as np
import torch
import onnxruntime as ort
from ppo_model import PPONetwork
# Load the models: locations of the PyTorch checkpoint and the ONNX model.
pth_path = 'models/ppo_model_best.pth'
onnx_path = 'models/ppo_model_best.onnx'

# Deserialize the checkpoint onto the CPU.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects; only point this at checkpoints from a trusted source.
device = torch.device('cpu')
checkpoint = torch.load(
    pth_path,
    map_location=device,
    weights_only=False,
)
state_dict = checkpoint['policy_state_dict']

# The hidden width is taken from the first dimension of the stored
# 'shared_net.0.weight' tensor instead of being hard-coded.
first_layer_weight = state_dict['shared_net.0.weight']
hidden_dim = first_layer_weight.shape[0]

# Rebuild the network, restore its weights and switch it to eval mode.
pth_model = PPONetwork(state_dim=572, hidden_dim=hidden_dim)
pth_model.load_state_dict(state_dict)
pth_model.eval()

# ONNX Runtime session for the ONNX model file.
onnx_session = ort.InferenceSession(onnx_path)
# Use the 19th sample (zero-based line index 18) of a recorded episode as the
# test case. `json` is already imported at the top of the file.
sample_path = 'Data/episode_20251205_130825_e7af75e5_325.jsonl'
sample_index = 18
# 'utf-8-sig' drops a leading byte-order mark if the file has one.
with open(sample_path, 'r', encoding='utf-8-sig') as f:
    # Stream the file and stop at the wanted line instead of reading it all.
    for line_number, line in enumerate(f):
        if line_number == sample_index:
            data = json.loads(line)
            break
    else:
        # Same exception type as indexing past the end of a list, but with a
        # message that names the file and the missing line.
        raise IndexError(
            f"{sample_path} has fewer than {sample_index + 1} lines"
        )
# 'State' becomes a float32 array and 'Actions' an int64 array.
state = np.array(data['State'], dtype=np.float32)
valid_actions = np.array(data['Actions'], dtype=np.int64)
# Show the test case that is fed to both models, followed by a blank line.
print(
    "测试用例:",
    f" state shape: {state.shape}",
    f" valid_actions: {valid_actions}",
    "",
    sep="\n",
)
# PyTorch (.pth) inference, printing each intermediate result on the way to
# the final logits (stored in `pth_logits` as a NumPy array).
print("="*80)
print("PTH 推理")
print("="*80)
with torch.no_grad():
    # Add a leading batch axis of size 1 to the state.
    state_tensor = torch.from_numpy(state).unsqueeze(0)
    actions_tensor = torch.from_numpy(valid_actions)

    # Step 1: state -> shared features.
    # Call the module itself rather than .forward() so that any registered
    # hooks run, as recommended by the nn.Module documentation.
    shared_features, _ = pth_model(state_tensor)
    print(f"1. Shared features shape: {shared_features.shape}")
    print(f" 前10个值: {shared_features[0, :10]}")
    print(f" 统计: mean={shared_features.mean().item():.6f}, std={shared_features.std().item():.6f}")

    # Step 2: actions -> action features.
    action_features = pth_model.action_encoder.encode(actions_tensor)
    print(f"\n2. Action features shape: {action_features.shape}")
    print(f" 第1个action的前10个值: {action_features[0, :10]}")
    print(f" 统计: mean={action_features.mean().item():.6f}, std={action_features.std().item():.6f}")

    # Step 3: pair the shared state features with every action's features by
    # expanding both to (1, num_actions, -1) and concatenating on the last axis.
    num_actions = len(valid_actions)
    expanded_state = shared_features.unsqueeze(1).expand(1, num_actions, -1)
    expanded_actions = action_features.unsqueeze(0).expand(1, num_actions, -1)
    combined = torch.cat([expanded_state, expanded_actions], dim=-1)
    print(f"\n3. Combined features shape: {combined.shape}")
    print(f" 第1个组合的前10个值: {combined[0, 0, :10]}")

    # Step 4: actor head over the flattened (state, action) pairs.
    flat_combined = combined.reshape(-1, combined.shape[-1])
    logits = pth_model.actor_head(flat_combined)
    logits = logits.reshape(1, num_actions)
    print(f"\n4. Logits shape: {logits.shape}")
    print(f" 值: {logits[0]}")

    # logits was just reshaped to (1, num_actions), so the leading axis is
    # always 1 and can be dropped unconditionally.
    logits = logits.squeeze(0)
    pth_logits = logits.numpy()
print(f"\n最终 PTH logits: {pth_logits}")
# ONNX Runtime inference on the same sample.
banner = "=" * 80
print("\n" + banner)
print("ONNX 推理")
print(banner)

# The state is reshaped to a single row; both inputs are cast explicitly to
# float32 / int64 before being handed to the session.
state_input = np.reshape(state, (1, -1)).astype(np.float32)
actions_input = valid_actions.astype(np.int64)
onnx_feed = {'state': state_input, 'valid_actions': actions_input}

# Passing None as the output list requests every output; the first one is
# used as the logits.
outputs = onnx_session.run(None, onnx_feed)
onnx_logits = outputs[0]
print(f"最终 ONNX logits: {onnx_logits}")
# Compare the PyTorch and ONNX logits element by element.
print("\n" + "="*80)
print("比较")
print("="*80)

# Flatten both results before subtracting. The two outputs may differ in
# whether they keep a leading axis, and letting NumPy broadcast mismatched
# shapes could silently turn an element-wise diff into a pairwise one.
pth_flat = np.ravel(pth_logits)
onnx_flat = np.ravel(onnx_logits)
if pth_flat.shape != onnx_flat.shape:
    raise ValueError(
        "logit shapes are not comparable: "
        f"PTH {np.shape(pth_logits)} vs ONNX {np.shape(onnx_logits)}"
    )

diff = np.abs(pth_flat - onnx_flat)
max_diff = diff.max()
# Largest absolute difference still reported as a match.
tolerance = 0.001

print(f"Logits 差异: {diff}")
print(f"最大差异: {max_diff:.6f}")
print(f"平均差异: {diff.mean():.6f}")
if max_diff < tolerance:
    print("\n✅ 一致!差异在可接受范围内")
else:
    print(f"\n❌ 不一致!最大差异 {max_diff:.6f} 超过阈值")