TH1/AITrainPython/test_encoder.py
2025-12-06 11:44:43 +08:00

58 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
检查具体的 action encoding 输出是否一致
"""
import torch
import numpy as np
import onnxruntime as ort
from ppo_model import PPONetwork, ActionEncoder
# Exercise a standalone ActionEncoder on its own.
# NOTE(review): this encoder is constructed directly rather than loaded from a
# checkpoint, so its weights are presumably the default initialisation and its
# embeddings are not expected to match the trained model's — confirm against
# ActionEncoder.__init__.
encoder = ActionEncoder(embedding_dim=64)
encoder.eval()

# A handful of sample action IDs: two small values and two large ones.
test_ids = torch.tensor([676, 721, 469771779, 536880643], dtype=torch.int64)
print("测试 ActionEncoder.encode():")
print(f"输入 action IDs: {test_ids}")

# Inference only; no autograd graph is needed.
with torch.no_grad():
    embeddings = encoder.encode(test_ids)

print(f"输出 embeddings shape: {embeddings.shape}")
# Rows 0 and 2 are printed so they can be compared by eye with the same rows
# produced by the loaded model's action encoder.
print(f"第一个 embedding (前10维): {embeddings[0, :10]}")
print(f"第三个 embedding (前10维): {embeddings[2, :10]}")
# Load the full trained PPONetwork from its .pth checkpoint.
banner = "=" * 80
print(f"\n{banner}")
print("加载完整 PTH 模型测试")
print(banner)

pth_path = 'models/ppo_model_best.pth'
# NOTE: weights_only=False performs full unpickling, which can execute
# arbitrary code — only load checkpoints from a trusted source.
checkpoint = torch.load(pth_path, map_location='cpu', weights_only=False)
state_dict = checkpoint['policy_state_dict']

# The hidden width is read from dim 0 of the first shared layer's weight.
# NOTE(review): assumes shared_net.0 is an nn.Linear, whose weight is
# (out_features, in_features) — confirm against PPONetwork.
hidden_dim = state_dict['shared_net.0.weight'].shape[0]

model = PPONetwork(state_dim=572, hidden_dim=hidden_dim)
model.load_state_dict(state_dict)
model.eval()
print(f"模型加载成功hidden_dim={hidden_dim}")
# Encode the same sample IDs with the action encoder inside the loaded model.
with torch.no_grad():
    embeddings_from_model = model.action_encoder.encode(test_ids)

print("\n从完整模型的 action_encoder 编码:")
print(f"第一个 embedding (前10维): {embeddings_from_model[0, :10]}")
print(f"第三个 embedding (前10维): {embeddings_from_model[2, :10]}")

# Summary statistics of the encoder's first layer weights.
# detach() keeps these reductions out of the autograd graph; only the
# resulting scalars are printed.
print("\nActionEncoder 第一层权重统计:")
first_layer_weight = model.action_encoder.encoder[0].weight.detach()
print(f" 形状: {first_layer_weight.shape}")
print(f" 均值: {first_layer_weight.mean().item():.6f}")
print(f" 标准差: {first_layer_weight.std().item():.6f}")
print(f" 最小值: {first_layer_weight.min().item():.6f}")
print(f" 最大值: {first_layer_weight.max().item():.6f}")