TH1/AITrainPython/debug_single_sample.py
2025-12-06 11:44:43 +08:00

105 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
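调试单个样本，检查PTH和ONNX的详细输出
Debug a single sample: feed the same state and valid-action list to the
PyTorch (.pth) checkpoint and to its ONNX export, then print both models'
logits and probabilities side by side to locate any mismatch.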
"""
import os
import json
import numpy as np
import torch
import onnxruntime as ort
from ppo_model import PPONetwork
# Load the models under comparison: the PyTorch checkpoint and its ONNX export.
pth_path = 'models/ppo_model_best.pth'
onnx_path = 'models/ppo_model_best.onnx'
device = torch.device('cpu')

# Load the PTH model.
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects;
# only point it at checkpoints from a trusted source.
checkpoint = torch.load(pth_path, map_location=device, weights_only=False)
# Accept either a training checkpoint dict with a 'policy_state_dict' entry
# or a bare state dict.
state_dict = checkpoint['policy_state_dict'] if 'policy_state_dict' in checkpoint else checkpoint
# Infer the hidden width from dim 0 of the first shared layer's weight;
# presumably an nn.Linear weight of shape (hidden_dim, state_dim) - confirm.
hidden_dim = state_dict['shared_net.0.weight'].shape[0]
# NOTE(review): state_dim is hard-coded; presumably it must equal the length
# of the sample 'State' vectors fed to the model - confirm when features change.
pth_model = PPONetwork(state_dim=572, hidden_dim=hidden_dim)
# Reuse the state dict selected above instead of repeating the lookup.
pth_model.load_state_dict(state_dict)
pth_model.eval()  # evaluation mode for inference

# Load the ONNX model. Pin the CPU provider so both models run on the same
# device; GPU-enabled onnxruntime builds may otherwise choose a different
# provider or require the provider list to be given explicitly.
onnx_session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
# Load one sample: the first JSON record of a recorded episode (.jsonl) file.
# 'utf-8-sig' also accepts a leading BOM, which json.loads would reject.
with open('Data/episode_20251205_130825_e7af75e5_325.jsonl', 'r', encoding='utf-8-sig') as episode_file:
    data = json.loads(episode_file.readline())

# The state vector fed to both models, and the candidate action IDs.
state = np.asarray(data['State'], dtype=np.float32)
valid_actions = np.asarray(data['Actions'], dtype=np.uint64)

for label, value in (
    ("State shape", state.shape),
    ("Valid actions", valid_actions),
    ("Number of valid actions", len(valid_actions)),
):
    print(f"{label}: {value}")
print()
# --- PTH prediction ---
with torch.no_grad():
    # Add a batch axis: the model is called on a batch of one state.
    state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    # NOTE(review): the PTH model receives the raw uint64 IDs, while the ONNX
    # model below receives them cast to int64. Some torch versions may refuse
    # uint64 arrays in torch.from_numpy - confirm which dtype
    # get_action_logits expects.
    action_ids_tensor = torch.from_numpy(valid_actions)
    print("=== PTH Model ===")
    print(f"State tensor shape: {state_tensor.shape}")
    print(f"Action IDs tensor shape: {action_ids_tensor.shape}")
    print(f"Action IDs dtype: {action_ids_tensor.dtype}")
    # forward() returns a pair; only the first element (the shared features)
    # is used to score the valid actions.
    shared_features, _ = pth_model.forward(state_tensor)
    print(f"Shared features shape: {shared_features.shape}")
    pth_logits = pth_model.get_action_logits(shared_features, action_ids_tensor)
    print(f"PTH Logits shape: {pth_logits.shape}")
    print(f"PTH Logits: {pth_logits.cpu().numpy()}")
    # Convert to probabilities over the flattened logits. If the logits keep
    # the batch axis, e.g. shape (1, N), a softmax over dim=0 would normalise
    # each size-1 column on its own and return all ones.
    pth_probs = torch.softmax(pth_logits.reshape(-1), dim=0)
    print(f"PTH Probs: {pth_probs.cpu().numpy()}")
    print(f"PTH Probs sum: {pth_probs.sum().item()}")
print()

# --- ONNX prediction ---
print("=== ONNX Model ===")
state_input = state.reshape(1, -1).astype(np.float32)
# The ONNX model is fed int64 action IDs. Casting a uint64 ID above the int64
# maximum wraps it to a negative value, so the two models would no longer see
# the same IDs - warn instead of letting that pass silently.
if np.any(valid_actions > np.uint64(np.iinfo(np.int64).max)):
    print("WARNING: some action IDs exceed the int64 range and wrap when cast for ONNX")
actions_input = valid_actions.astype(np.int64)
print(f"State input shape: {state_input.shape}")
print(f"Actions input shape: {actions_input.shape}")
print(f"Actions input dtype: {actions_input.dtype}")
outputs = onnx_session.run(
    None,
    {
        'state': state_input,
        'valid_actions': actions_input,
    },
)
onnx_logits = outputs[0]
print(f"ONNX Logits shape: {onnx_logits.shape}")
print(f"ONNX Logits: {onnx_logits}")
# Numerically stable softmax over the flattened logits, matching the PTH side.
onnx_logits_vec = np.asarray(onnx_logits).reshape(-1)
exp_logits = np.exp(onnx_logits_vec - np.max(onnx_logits_vec))
onnx_probs = exp_logits / exp_logits.sum()
print(f"ONNX Probs: {onnx_probs}")
print(f"ONNX Probs sum: {onnx_probs.sum()}")
print()

# --- Comparison ---
print("=== Comparison ===")
pth_logits_np = pth_logits.cpu().numpy().reshape(-1)
pth_probs_np = pth_probs.cpu().numpy()
# Both models must produce one logit per valid action. Otherwise broadcasting
# (e.g. against a size-1 array) could hide the mismatch, and the argmax below
# would not index the right entry of valid_actions.
if pth_logits_np.size != onnx_logits_vec.size or pth_logits_np.size != len(valid_actions):
    raise ValueError(
        f"Logit count mismatch: PTH {tuple(pth_logits.shape)}, "
        f"ONNX {onnx_logits.shape}, valid actions {len(valid_actions)}"
    )
logit_diff = np.abs(pth_logits_np - onnx_logits_vec)
prob_diff = np.abs(pth_probs_np - onnx_probs)
print(f"Max logit difference: {np.max(logit_diff):.6f}")
print(f"Max prob difference: {np.max(prob_diff):.6f}")
# Greedy choice of each model vs. the action stored under 'SelectedAction'.
print(f"\nPTH selected action: {valid_actions[np.argmax(pth_probs_np)]}")
print(f"ONNX selected action: {valid_actions[np.argmax(onnx_probs)]}")
print(f"Expected action: {data['SelectedAction']}")