58 lines
1.9 KiB
Python
58 lines
1.9 KiB
Python
"""
|
||
检查具体的 action encoding 输出是否一致
|
||
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
import onnxruntime as ort
|
||
from ppo_model import PPONetwork, ActionEncoder
|
||
|
||
# Exercise a stand-alone ActionEncoder instance.
# NOTE(review): no checkpoint is loaded into this instance here, so its output
# presumably reflects only whatever parameters the constructor sets up.
encoder = ActionEncoder(embedding_dim=64)
encoder.eval()

# A few sample action IDs to run through the encoder.
sample_action_ids = [676, 721, 469771779, 536880643]
test_ids = torch.tensor(sample_action_ids, dtype=torch.int64)

print("测试 ActionEncoder.encode():")
print(f"输入 action IDs: {test_ids}")

# Gradients are not needed for this read-only inspection.
with torch.no_grad():
    embeddings = encoder.encode(test_ids)
    print(f"输出 embeddings shape: {embeddings.shape}")

    # Show the first 10 values of the embeddings at index 0 and index 2.
    first_embedding_head = embeddings[0, :10]
    print(f"第一个 embedding (前10维): {first_embedding_head}")
    third_embedding_head = embeddings[2, :10]
    print(f"第三个 embedding (前10维): {third_embedding_head}")

# Next, exercise the action-encoding part of the full model.
banner_rule = "=" * 80
print("\n" + banner_rule)
print("加载完整 PTH 模型测试")
print(banner_rule)

pth_path = 'models/ppo_model_best.pth'

# NOTE(review): weights_only=False allows torch.load to unpickle arbitrary
# Python objects; only point this at checkpoints from a trusted source.
checkpoint = torch.load(
    pth_path,
    map_location='cpu',
    weights_only=False,
)
state_dict = checkpoint['policy_state_dict']

# hidden_dim is read from the stored weight's first dimension instead of
# being hard-coded, so the rebuilt network matches the checkpoint.
first_shared_weight = state_dict['shared_net.0.weight']
hidden_dim = first_shared_weight.shape[0]

model = PPONetwork(state_dim=572, hidden_dim=hidden_dim)
model.load_state_dict(state_dict)
model.eval()

print(f"模型加载成功,hidden_dim={hidden_dim}")

# Encode the same action IDs with the loaded model's own action encoder.
with torch.no_grad():
    embeddings_from_model = model.action_encoder.encode(test_ids)

    print("\n从完整模型的 action_encoder 编码:")
    # Same slices as the stand-alone check: index 0 and index 2, first 10 values.
    model_first_head = embeddings_from_model[0, :10]
    print(f"第一个 embedding (前10维): {model_first_head}")
    model_third_head = embeddings_from_model[2, :10]
    print(f"第三个 embedding (前10维): {model_third_head}")

# Inspect the weights of the first layer of the model's action encoder.
print("\nActionEncoder 第一层权重统计:")
first_layer_weight = model.action_encoder.encoder[0].weight

# Shape first, then scalar summaries (mean, std, min, max) to six decimals.
print(f" 形状: {first_layer_weight.shape}")

weight_mean = first_layer_weight.mean().item()
print(f" 均值: {weight_mean:.6f}")

weight_std = first_layer_weight.std().item()
print(f" 标准差: {weight_std:.6f}")

weight_min = first_layer_weight.min().item()
print(f" 最小值: {weight_min:.6f}")

weight_max = first_layer_weight.max().item()
print(f" 最大值: {weight_max:.6f}")