277 lines
11 KiB
Python
277 lines
11 KiB
Python
"""
|
||
PPO Model Architecture
|
||
包含 Actor-Critic 网络结构
|
||
"""
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
|
||
class PPONetwork(nn.Module):
    """PPO actor-critic network.

    A shared feature extractor feeds two independent heads: a policy head
    (actor) that scores one (state, action) pair at a time, and a value head
    (critic) that estimates the value of the state.
    """

    def __init__(self, state_dim=572, hidden_dim=512, action_encoding_dim=64):
        """Build the action encoder, the shared trunk and both heads.

        Args:
            state_dim: Size of the state vector (auto-detected from the
                training data, usually 570-572).
            hidden_dim: Width of the hidden layers.
            action_encoding_dim: Size of the action embedding.
        """
        super().__init__()

        # Maps raw action IDs to dense feature vectors.
        self.action_encoder = ActionEncoder(embedding_dim=action_encoding_dim)

        # Shared feature extractor: state_dim -> hidden_dim // 2.
        self.shared_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
        )

        # Policy head (actor). Input is the concatenation of state features
        # and action features; output is a single logit for that action.
        self.actor_head = nn.Sequential(
            nn.Linear(hidden_dim // 2 + action_encoding_dim, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, 1),
        )

        # Value head (critic). Outputs the state value.
        self.critic_head = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, 1),
        )

    def forward(self, state, valid_actions_mask=None):
        """Compute the shared features and the state value.

        Action logits are NOT produced here; use ``get_action_logits`` with
        the returned features.

        Args:
            state: (batch_size, state_dim) tensor.
            valid_actions_mask: Not read by this method; the parameter is kept
                so existing callers that pass it keep working.

        Returns:
            shared_features: (batch_size, hidden_dim // 2)
            state_value: (batch_size, 1)
        """
        shared_features = self.shared_net(state)
        state_value = self.critic_head(shared_features)
        return shared_features, state_value

    def get_action_logits(self, shared_features, action_ids):
        """Score every given action ID against every state in the batch.

        Args:
            shared_features: (batch_size, hidden_dim // 2)
            action_ids: (action_count,) tensor of action IDs.

        Returns:
            action_logits: (batch_size, action_count). The result is always
            2-D so that an ONNX trace stays correct; callers squeeze it
            themselves when they need to.
        """
        batch_size = shared_features.shape[0]
        action_count = action_ids.shape[0]

        # (action_count, action_encoding_dim)
        action_features = self.action_encoder.encode(action_ids)

        # Broadcast both sides to (batch_size, action_count, feature_dim).
        expanded_state_features = shared_features.unsqueeze(1).expand(
            batch_size, action_count, -1
        )
        expanded_action_features = action_features.unsqueeze(0).expand(
            batch_size, action_count, -1
        )

        # (batch_size, action_count, hidden_dim // 2 + action_encoding_dim)
        combined_features = torch.cat(
            [expanded_state_features, expanded_action_features], dim=-1
        )

        # Flatten to 2-D for the actor head, then restore the batch layout.
        flat_features = combined_features.reshape(batch_size * action_count, -1)
        flat_logits = self.actor_head(flat_features)

        # Numerical-stability clamp; evaluate() applies the same bounds.
        flat_logits = torch.clamp(flat_logits, min=-50.0, max=50.0)

        return flat_logits.reshape(batch_size, action_count)

    def get_action_probs(self, state, action_ids):
        """Return the softmax distribution over the given actions.

        Args:
            state: (1, state_dim) or (batch_size, state_dim)
            action_ids: (num_actions,) tensor of action IDs.

        Returns:
            probs: (num_actions,) if batch_size == 1, otherwise
            (batch_size, num_actions).
        """
        shared_features, _ = self.forward(state)
        action_logits = self.get_action_logits(shared_features, action_ids)

        # Drop the batch axis for the single-state case.
        if action_logits.shape[0] == 1:
            action_logits = action_logits.squeeze(0)

        return F.softmax(action_logits, dim=-1)

    def evaluate(self, state, action_index, valid_actions_ids_list, action_lengths=None):
        """Evaluate chosen actions for a batch with ragged action sets.

        All (state, action) pairs of the batch are scored in one pass through
        the actor head and then scattered into a padded
        (batch_size, max_actions) matrix for the softmax.

        Args:
            state: (batch_size, state_dim)
            action_index: (batch_size,) index of the chosen action inside each
                sample's valid-action list.
            valid_actions_ids_list: Sequence of 1-D tensors, one per sample,
                holding that sample's valid action IDs.
            action_lengths: Optional (batch_size,) pre-computed number of
                valid actions per sample. Any integer dtype/device is
                accepted; it is converted to int64 on ``state.device``.

        Returns:
            log_probs: (batch_size,) log-probability of the chosen action.
            state_values: (batch_size,)
            entropy: (batch_size,) entropy of each sample's distribution.
        """
        shared_features, state_values = self.forward(state)
        batch_size = state.shape[0]
        device = state.device

        # Step 1: concatenate every sample's action IDs into one flat tensor.
        all_action_ids = torch.cat(valid_actions_ids_list)

        # Step 2: per-sample action counts. repeat_interleave needs an integer
        # tensor on the same device, so normalise dtype/device here instead of
        # inheriting them from the action-ID tensors or from the caller.
        if action_lengths is not None:
            lengths = torch.as_tensor(action_lengths, dtype=torch.long, device=device)
        else:
            lengths = torch.tensor(
                [ids.shape[0] for ids in valid_actions_ids_list],
                dtype=torch.long,
                device=device,
            )
        max_actions = int(lengths.max())

        # Step 3: row index of the owning sample for every flat action.
        sample_indices = torch.arange(batch_size, device=device).repeat_interleave(lengths)

        # Step 4: score all (state, action) pairs in one large batch.
        expanded_features = shared_features[sample_indices]
        action_features = self.action_encoder.encode(all_action_ids)
        combined = torch.cat([expanded_features, action_features], dim=-1)
        all_logits = self.actor_head(combined).squeeze(-1)
        all_logits = torch.clamp(all_logits, min=-50.0, max=50.0)

        # Step 5: scatter into a padded matrix. -1e4 is representable in
        # float16 and far below the clamp range, so padding gets ~0 probability.
        padded_logits = torch.full(
            (batch_size, max_actions), -1e4, dtype=all_logits.dtype, device=device
        )
        # Exclusive prefix sum = offset of each sample's first action.
        row_starts = lengths.cumsum(0) - lengths
        global_indices = torch.arange(all_logits.shape[0], device=device)
        col_indices = global_indices - row_starts[sample_indices]
        padded_logits[sample_indices, col_indices] = all_logits

        # Step 6: log-softmax, chosen-action log-prob and entropy.
        log_probs_all = F.log_softmax(padded_logits, dim=-1)
        # gather requires an int64 index.
        chosen = action_index.long().unsqueeze(1)
        log_probs = log_probs_all.gather(1, chosen).squeeze(1)

        probs = torch.exp(log_probs_all)
        # Mask out padded columns so they contribute nothing to the entropy.
        mask = torch.arange(max_actions, device=device).unsqueeze(0) < lengths.unsqueeze(1)
        entropies = -(probs * log_probs_all * mask).sum(dim=-1)

        return log_probs, state_values.squeeze(-1), entropies
|
||
|
||
|
||
class ActionEncoder(nn.Module):
    """Encode 64-bit action IDs as dense feature vectors."""

    def __init__(self, embedding_dim=64):
        """Create the projection MLP and the fixed frequency table.

        Args:
            embedding_dim: Size of the produced embedding.
        """
        super().__init__()
        self.embedding_dim = embedding_dim

        # Maps the 64 sin/cos features (32 frequencies x 2) into the
        # embedding space.
        self.encoder = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, embedding_dim),
            nn.Tanh(),
        )

        # Pre-computed frequency table 2**0 .. 2**31. Registered as a buffer
        # so it is saved/loaded with the model but never trained, and so the
        # exported ONNX graph does not recompute it.
        freqs = torch.tensor([2.0 ** i for i in range(32)], dtype=torch.float32)
        self.register_buffer('_freqs', freqs)

    def encode(self, action_ids):
        """Encode action IDs into embeddings.

        The same log-normalised sinusoidal encoding is used for training and
        inference. It must stay identical between the PyTorch checkpoint and
        the ONNX export:

        1. the pre-registered ``_freqs`` buffer is used (nothing is computed
           dynamically);
        2. the feature order is fixed: sin0, cos0, sin1, cos1, ..., sin31,
           cos31;
        3. the normalisation constant is fixed at 23.0.

        Args:
            action_ids: Tensor of action IDs; it is flattened to (N,).

        Returns:
            embeddings: (N, embedding_dim)
        """
        # Flatten to 1-D and convert to int64.
        action_ids_flat = action_ids.reshape(-1).long()

        # IDs may span a large range (0 to 2**32); log(1 + id) compresses it
        # to roughly [0, 22.18]. The +1 avoids log(0).
        # NOTE(review): the IDs are cast to float32 first, which cannot
        # represent every integer above 2**24 exactly - confirm this is
        # acceptable for the ID space in use.
        log_ids = torch.log(action_ids_flat.float() + 1.0)

        # Scale by the fixed constant 23.0.
        normalized_ids = log_ids / 23.0

        # Sinusoidal features at 32 frequencies (similar to Transformer
        # positional encoding): (N, 1) * (1, 32) -> (N, 32).
        angles = normalized_ids.unsqueeze(1) * self._freqs.unsqueeze(0)

        # Stack on a trailing axis -> (N, 32, 2), then flatten -> (N, 64).
        # Row-major flattening yields exactly sin0, cos0, sin1, cos1, ...,
        # the order the trained weights expect.
        position_encoded = torch.stack(
            (torch.sin(angles), torch.cos(angles)), dim=-1
        ).flatten(1)

        # (N, 64) -> (N, embedding_dim)
        return self.encoder(position_encoded)
|
||
|