TH1/AITrainPython/ppo_model.py
2025-12-06 11:44:43 +08:00

277 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
PPO Model Architecture
包含 Actor-Critic 网络结构
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class PPONetwork(nn.Module):
    """
    PPO actor-critic network.

    A shared MLP trunk extracts state features. The critic head maps those
    features to a scalar state value. The actor head scores each candidate
    action by concatenating the state features with an ``ActionEncoder``
    embedding of the action ID, producing one logit per action.
    """

    # Logits are clamped to [-_LOGIT_CLAMP, _LOGIT_CLAMP] for numerical
    # stability; get_action_logits() and evaluate() must use the same bound.
    _LOGIT_CLAMP = 50.0
    # Padding logit for non-existent action slots in evaluate(). It is still
    # representable in float16 and yields ~zero probability after softmax.
    _PAD_LOGIT = -1e4

    def __init__(self, state_dim=572, hidden_dim=512, action_encoding_dim=64):
        """
        Args:
            state_dim: State vector size (per the original notes it is detected
                from the training data, usually 570-572).
            hidden_dim: Width of the shared trunk; the trunk outputs
                ``hidden_dim // 2`` features.
            action_encoding_dim: Size of the action embedding produced by
                ``ActionEncoder``.
        """
        super(PPONetwork, self).__init__()

        # Maps raw action IDs to fixed-size feature vectors.
        self.action_encoder = ActionEncoder(embedding_dim=action_encoding_dim)

        # Shared trunk: state_dim -> hidden_dim -> hidden_dim -> hidden_dim // 2.
        self.shared_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
        )

        # Policy head (actor): input is state features + action features,
        # output is a single logit for that (state, action) pair.
        self.actor_head = nn.Sequential(
            nn.Linear(hidden_dim // 2 + action_encoding_dim, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, 1),
        )

        # Value head (critic): state features -> scalar state value.
        self.critic_head = nn.Sequential(
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(hidden_dim // 4, 1),
        )

    def forward(self, state, valid_actions_mask=None):
        """
        Run the shared trunk and the critic head.

        This does NOT compute action logits. Pass the returned features to
        ``get_action_logits`` for that.

        Args:
            state: (batch_size, state_dim) tensor.
            valid_actions_mask: Unused; accepted only for backward
                compatibility with existing callers.

        Returns:
            shared_features: (batch_size, hidden_dim // 2)
            state_value: (batch_size, 1)
        """
        shared_features = self.shared_net(state)
        state_value = self.critic_head(shared_features)
        return shared_features, state_value

    def get_action_logits(self, shared_features, action_ids):
        """
        Compute one logit for every (state, action) pair.

        Args:
            shared_features: (batch_size, hidden_dim // 2) features from
                ``forward``.
            action_ids: (action_count,) tensor of action IDs.

        Returns:
            (batch_size, action_count) logits clamped to
            [-_LOGIT_CLAMP, _LOGIT_CLAMP]. The result is always 2-D, even when
            batch_size == 1, so ONNX tracing sees a fixed rank. Callers squeeze
            as needed.
        """
        batch_size = shared_features.shape[0]
        action_count = action_ids.shape[0]

        # (action_count, action_encoding_dim)
        action_features = self.action_encoder.encode(action_ids)

        # Pair every state with every action:
        #   states  -> (batch_size, action_count, hidden_dim // 2)
        #   actions -> (batch_size, action_count, action_encoding_dim)
        expanded_state_features = shared_features.unsqueeze(1).expand(batch_size, action_count, -1)
        expanded_action_features = action_features.unsqueeze(0).expand(batch_size, action_count, -1)
        combined_features = torch.cat([expanded_state_features, expanded_action_features], dim=-1)

        # Flatten the pairs so actor_head sees a plain 2-D batch.
        flat_features = combined_features.reshape(batch_size * action_count, -1)
        flat_logits = self.actor_head(flat_features)
        flat_logits = torch.clamp(flat_logits, min=-self._LOGIT_CLAMP, max=self._LOGIT_CLAMP)

        return flat_logits.reshape(batch_size, action_count)

    def get_action_probs(self, state, action_ids):
        """
        Return the softmax policy over ``action_ids``.

        Args:
            state: (1, state_dim) or (batch_size, state_dim) tensor.
            action_ids: (num_actions,) tensor of action IDs.

        Returns:
            (num_actions,) probabilities when batch_size == 1, otherwise
            (batch_size, num_actions).
        """
        shared_features, _ = self.forward(state)
        action_logits = self.get_action_logits(shared_features, action_ids)
        # Drop the batch axis for single-state queries.
        if action_logits.shape[0] == 1:
            action_logits = action_logits.squeeze(0)
        return F.softmax(action_logits, dim=-1)

    def evaluate(self, state, action_index, valid_actions_ids_list, action_lengths=None):
        """
        Evaluate state-action pairs for the PPO update (fully vectorized).

        Samples may have different numbers of valid actions. All action IDs are
        scored as one flat batch, scattered into a padded
        (batch_size, max_actions) logit matrix, and normalized per row.

        Args:
            state: (batch_size, state_dim) tensor.
            action_index: (batch_size,) index of the chosen action within each
                sample's valid-action list.
            valid_actions_ids_list: list of batch_size 1-D tensors holding each
                sample's valid action IDs.
            action_lengths: optional (batch_size,) tensor with the precomputed
                number of valid actions per sample. It is derived from
                ``valid_actions_ids_list`` when omitted.

        Returns:
            log_probs: (batch_size,) log-probabilities of the chosen actions.
            state_values: (batch_size,)
            entropies: (batch_size,) policy entropy over each sample's valid
                actions (padding excluded).
        """
        shared_features, state_values = self.forward(state)
        batch_size = state.shape[0]
        device = state.device

        # 1) Concatenate every sample's action IDs into one flat tensor.
        all_action_ids = torch.cat(valid_actions_ids_list)

        # 2) Per-sample action counts. repeat_interleave/cumsum/indexing need
        #    integer (int64) counts on the index device. Deriving them with
        #    va.new_tensor(...) would inherit the action-ID dtype (documented
        #    as uint64), which repeat_interleave rejects.
        if action_lengths is not None:
            lengths = action_lengths.to(device=device, dtype=torch.long)
        else:
            lengths = torch.tensor(
                [va.shape[0] for va in valid_actions_ids_list],
                dtype=torch.long,
                device=device,
            )
        max_actions = int(lengths.max())

        # 3) Owning sample of every flat action: [0]*len0 + [1]*len1 + ...
        sample_indices = torch.arange(batch_size, device=device).repeat_interleave(lengths)

        # 4) Score every (state, action) pair in one large batch.
        expanded_features = shared_features[sample_indices]
        action_features = self.action_encoder.encode(all_action_ids)
        combined = torch.cat([expanded_features, action_features], dim=-1)
        all_logits = self.actor_head(combined).squeeze(-1)
        all_logits = torch.clamp(all_logits, min=-self._LOGIT_CLAMP, max=self._LOGIT_CLAMP)

        # 5) Scatter the flat logits into a padded (batch_size, max_actions)
        #    matrix; unused slots keep _PAD_LOGIT.
        padded_logits = torch.full(
            (batch_size, max_actions),
            self._PAD_LOGIT,
            dtype=all_logits.dtype,
            device=device,
        )
        # Exclusive prefix sum = start offset of each sample in the flat list.
        starts = F.pad(lengths.cumsum(0), (1, 0), value=0)
        # Column of each flat action inside its own row.
        col_indices = torch.arange(all_logits.shape[0], device=device) - starts[sample_indices]
        padded_logits[sample_indices, col_indices] = all_logits

        # 6) Row-wise log-softmax, chosen-action log-probs and masked entropy.
        log_probs_all = F.log_softmax(padded_logits, dim=-1)
        log_probs = log_probs_all.gather(1, action_index.long().unsqueeze(1)).squeeze(1)

        probs = log_probs_all.exp()
        valid_mask = torch.arange(max_actions, device=device).unsqueeze(0) < lengths.unsqueeze(1)
        entropies = -(probs * log_probs_all * valid_mask).sum(dim=-1)

        return log_probs, state_values.squeeze(-1), entropies
class ActionEncoder(nn.Module):
    """
    Encode 64-bit action IDs as learned feature vectors.

    Each ID is log-compressed and expanded into 32 sine/cosine frequency
    pairs, giving a fixed 64-dim encoding. A small MLP then projects it to
    ``embedding_dim`` features in [-1, 1] (Tanh output).
    """

    def __init__(self, embedding_dim=64):
        super(ActionEncoder, self).__init__()
        self.embedding_dim = embedding_dim

        # MLP from the fixed 64-dim sin/cos encoding to the embedding space.
        self.encoder = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, embedding_dim),
            nn.Tanh(),
        )

        # Precomputed frequency table 2**0 .. 2**31. It is registered as a
        # buffer so it is saved/loaded with the model but never trained, and
        # is not recomputed dynamically (keeps PyTorch and ONNX consistent).
        freqs = torch.tensor([2.0 ** i for i in range(32)], dtype=torch.float32)
        self.register_buffer('_freqs', freqs)

    def encode(self, action_ids):
        """
        Encode action IDs into ``(N, embedding_dim)`` feature vectors.

        This method must behave identically for PyTorch (.pth) training and
        ONNX inference, so the following invariants are fixed:
          1. Frequencies come from the registered ``_freqs`` buffer.
          2. Features are interleaved as sin0, cos0, sin1, cos1, ..., sin31, cos31.
          3. Log-IDs are normalized by the constant 23.0.

        Args:
            action_ids: tensor of action IDs of any shape (documented upstream
                as uint64). It is flattened to N elements and cast to int64.

        Returns:
            (N, embedding_dim) tensor.
        """
        ids = action_ids.reshape(-1).long()

        # IDs may span up to ~2**32. log(1 + id) compresses that to
        # [0, ~22.18]; dividing by 23.0 maps it into roughly [0, 1].
        # Kept as log(x + 1.0) rather than log1p so values stay bit-identical
        # to existing trained/exported models.
        # NOTE(review): float32 cannot represent IDs above 2**24 exactly, and
        # IDs >= 2**63 passed as uint64 would wrap negative under .long()
        # (log -> NaN). Confirm the real ID range with the caller.
        normalized = torch.log(ids.float() + 1.0) / 23.0

        # (N, 1) * (1, 32) -> (N, 32) angles, one column per frequency.
        angles = normalized.unsqueeze(1) * self._freqs.unsqueeze(0)

        # Stack to (N, 32, 2) and flatten to (N, 64). This yields exactly the
        # sin0, cos0, sin1, cos1, ... order the trained weights expect.
        position_encoded = torch.stack((torch.sin(angles), torch.cos(angles)), dim=2).flatten(start_dim=1)

        return self.encoder(position_encoded)