105 lines
3.2 KiB
Python
105 lines
3.2 KiB
Python
"""
|
||
Debug a single sample: inspect the detailed outputs of the PTH and ONNX models.
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import numpy as np
|
||
import torch
|
||
import onnxruntime as ort
|
||
from ppo_model import PPONetwork
|
||
|
||
# Load the models
pth_path = 'models/ppo_model_best.pth'
onnx_path = 'models/ppo_model_best.onnx'
device = torch.device('cpu')

# Load the PyTorch checkpoint.
# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the file. Only point this at
# checkpoints you produced yourself or otherwise trust.
checkpoint = torch.load(pth_path, map_location=device, weights_only=False)

# The checkpoint is either a dict wrapping the weights under
# 'policy_state_dict' or the bare state dict itself.
state_dict = checkpoint['policy_state_dict'] if 'policy_state_dict' in checkpoint else checkpoint

# The first dimension of 'shared_net.0.weight' is used as the hidden size.
hidden_dim = state_dict['shared_net.0.weight'].shape[0]
# NOTE(review): state_dim is hard-coded to 572; presumably it could be read
# from the same weight's second dimension - confirm against PPONetwork.
pth_model = PPONetwork(state_dim=572, hidden_dim=hidden_dim)
# Reuse the state dict resolved above instead of repeating the lookup.
pth_model.load_state_dict(state_dict)
pth_model.eval()

# Load the ONNX model.
# The execution provider is pinned to CPU so the ONNX side runs on the same
# kind of device as the PyTorch model (device above is CPU). onnxruntime
# builds that ship additional providers may otherwise demand an explicit
# choice or pick a different one - TODO confirm for the installed version.
onnx_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
|
||
|
||
# Load one sample: only the first line of the episode file is read, and it is
# parsed as a single JSON record.
sample_path = 'Data/episode_20251205_130825_e7af75e5_325.jsonl'
# 'utf-8-sig' drops a leading BOM if the file starts with one.
with open(sample_path, 'r', encoding='utf-8-sig') as f:
    line = f.readline()

# Fail with a clear message instead of an opaque JSON decode error when the
# file is empty or its first line is blank.
if not line.strip():
    raise ValueError(f"No sample found on the first line of {sample_path}")
data = json.loads(line)

# The record's 'State' becomes a float32 array; 'Actions' becomes uint64 ids.
state = np.array(data['State'], dtype=np.float32)
valid_actions = np.array(data['Actions'], dtype=np.uint64)

print(f"State shape: {state.shape}")
print(f"Valid actions: {valid_actions}")
print(f"Number of valid actions: {len(valid_actions)}")
print()
|
||
|
||
# PTH prediction (gradient tracking disabled; this is inference only)
with torch.no_grad():
    # Add a leading batch dimension of size 1 to the state.
    state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    # NOTE(review): valid_actions is a uint64 array here, while the ONNX path
    # feeds the same ids as int64. uint64 support in torch.from_numpy and in
    # tensor ops depends on the installed torch version - confirm which dtype
    # PPONetwork.get_action_logits expects.
    action_ids_tensor = torch.from_numpy(valid_actions)

    print("=== PTH Model ===")
    print(f"State tensor shape: {state_tensor.shape}")
    print(f"Action IDs tensor shape: {action_ids_tensor.shape}")
    print(f"Action IDs dtype: {action_ids_tensor.dtype}")

    # Get the logits. The module is called directly rather than through
    # .forward() so that nn.Module's call machinery (registered hooks) runs.
    # The second returned value is not needed here.
    shared_features, _ = pth_model(state_tensor)
    print(f"Shared features shape: {shared_features.shape}")

    pth_logits = pth_model.get_action_logits(shared_features, action_ids_tensor)
    print(f"PTH Logits shape: {pth_logits.shape}")
    print(f"PTH Logits: {pth_logits.cpu().numpy()}")

    # Convert to probabilities (softmax along dimension 0 of the logits).
    pth_probs = torch.softmax(pth_logits, dim=0)
    print(f"PTH Probs: {pth_probs.cpu().numpy()}")
    print(f"PTH Probs sum: {pth_probs.sum().item()}")
    print()
|
||
|
||
# ONNX prediction
print("=== ONNX Model ===")
# The state is fed as a single float32 row; the action ids are cast to int64.
actions_input = valid_actions.astype(np.int64)
state_input = state.astype(np.float32).reshape(1, -1)

print(f"State input shape: {state_input.shape}")
print(f"Actions input shape: {actions_input.shape}")
print(f"Actions input dtype: {actions_input.dtype}")

# Passing None as the output list asks the session for all of its outputs.
onnx_feed = {'state': state_input, 'valid_actions': actions_input}
outputs = onnx_session.run(None, onnx_feed)

# The first output is taken as the action logits.
onnx_logits = outputs[0]
print(f"ONNX Logits shape: {onnx_logits.shape}")
print(f"ONNX Logits: {onnx_logits}")

# Convert to probabilities: softmax over all logits, with the maximum
# subtracted first so the exponentials cannot overflow.
shifted_logits = onnx_logits - onnx_logits.max()
exp_logits = np.exp(shifted_logits)
onnx_probs = exp_logits / np.sum(exp_logits)
print(f"ONNX Probs: {onnx_probs}")
print(f"ONNX Probs sum: {onnx_probs.sum()}")
print()
|
||
|
||
# Comparison
print("=== Comparison ===")
# Convert the PyTorch results to NumPy once and reuse them below.
pth_logits_np = pth_logits.cpu().numpy()
pth_probs_np = pth_probs.cpu().numpy()

# Element-wise absolute differences between the two backends.
logit_diff = np.abs(pth_logits_np - onnx_logits)
prob_diff = np.abs(pth_probs_np - onnx_probs)

max_logit_diff = logit_diff.max()
max_prob_diff = prob_diff.max()
print(f"Max logit difference: {max_logit_diff:.6f}")
print(f"Max prob difference: {max_prob_diff:.6f}")

# Each backend's pick is the valid action at the index of its highest
# probability; the recorded choice comes from the sample's 'SelectedAction'.
pth_choice = valid_actions[pth_probs_np.argmax()]
onnx_choice = valid_actions[onnx_probs.argmax()]
expected_choice = data['SelectedAction']
print(f"\nPTH selected action: {pth_choice}")
print(f"ONNX selected action: {onnx_choice}")
print(f"Expected action: {expected_choice}")
|
||
|