135 lines
4.7 KiB
Python
135 lines
4.7 KiB
Python
"""
Model inference interface.

Used to make in-game decisions with a trained model.
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
from ppo_model import PPONetwork
|
||
|
||
|
||
class PPOInference:
    """Inference wrapper around a trained PPO policy network.

    Restores a checkpoint into a ``PPONetwork`` and exposes helpers to choose
    an action among the currently valid ones and to estimate a state's value.
    """

    def __init__(self, model_path, state_dim=572, hidden_dim=512, device=None):
        """Build the network and restore its weights from ``model_path``.

        Args:
            model_path: Path of the checkpoint file. The loaded object must
                provide a ``'policy_state_dict'`` entry.
            state_dim: State dimension forwarded to ``PPONetwork``.
            hidden_dim: Hidden-layer dimension forwarded to ``PPONetwork``.
            device: Inference device. When ``None`` (or any falsy value),
                CUDA is used if available, otherwise the CPU.
        """
        if not device:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        # Build the network on the target device, then restore the weights.
        network = PPONetwork(state_dim=state_dim, hidden_dim=hidden_dim)
        self.policy = network.to(self.device)

        # NOTE(security): torch.load unpickles the file, so only checkpoints
        # coming from a trusted source should be loaded here.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.policy.eval()  # switch the module to evaluation mode

        print(f"模型加载成功: {model_path}")
        print(f"推理设备: {self.device}")

    def _to_state_tensor(self, state):
        """Return ``state`` as a float32 tensor with a leading batch axis."""
        tensor = torch.as_tensor(state, dtype=torch.float32)
        return tensor.unsqueeze(0).to(self.device)

    def predict_action(self, state, valid_actions, deterministic=False):
        """Choose one action among ``valid_actions`` for ``state``.

        Args:
            state: Array-like current state; the docstring of the original
                code describes it as a ``(state_dim,)`` numpy array.
            valid_actions: Array-like of valid action codes (described by the
                original code as a uint64 numpy array).
            deterministic: If True, pick the most probable action; otherwise
                sample an action from the predicted distribution.

        Returns:
            A ``(selected_action, action_probs)`` tuple: the chosen entry of
            ``valid_actions`` and the numpy array of probabilities returned
            by the policy for the valid actions.

        Raises:
            ValueError: If ``valid_actions`` is empty.
        """
        # Accept any array-like; an ndarray input is passed through untouched.
        action_array = np.asarray(valid_actions)
        n_actions = len(action_array)
        if n_actions == 0:
            raise ValueError("valid_actions must contain at least one action")

        with torch.no_grad():
            state_tensor = self._to_state_tensor(state)

            # The policy scores the concrete action ids, not just their count.
            # NOTE(review): torch.from_numpy support for uint64 arrays depends
            # on the installed PyTorch version — confirm for the deployment.
            action_ids_tensor = torch.from_numpy(action_array).to(self.device)

            # Assumes get_action_probs returns one probability per valid
            # action — TODO confirm against PPONetwork.
            action_probs = self.policy.get_action_probs(state_tensor, action_ids_tensor)
            action_probs = action_probs.cpu().numpy()

        if deterministic:
            # Greedy choice: the action with the highest probability.
            action_index = np.argmax(action_probs)
        else:
            # Stochastic choice: sample according to the distribution.
            action_index = np.random.choice(n_actions, p=action_probs)

        return action_array[action_index], action_probs

    def predict_value(self, state):
        """Estimate the value of ``state``.

        Args:
            state: Array-like state; described by the original code as a
                ``(state_dim,)`` numpy array.

        Returns:
            The state value as a Python float.
        """
        with torch.no_grad():
            # Call the module itself (not .forward) so that any hooks
            # registered on it are honoured.
            _, value = self.policy(self._to_state_tensor(state))
        return value.item()

    def batch_predict_actions(self, states, valid_actions_list, deterministic=False):
        """Choose an action for every state of a batch.

        Each sample is handled by ``predict_action`` in turn.

        Args:
            states: Sequence of states (described by the original code as a
                ``(batch_size, state_dim)`` numpy array).
            valid_actions_list: One collection of valid actions per state.
            deterministic: Forwarded to ``predict_action``.

        Returns:
            A ``(selected_actions, all_probs)`` tuple of lists holding, for
            each sample, the chosen action and its probability array.

        Raises:
            ValueError: If ``states`` and ``valid_actions_list`` differ in
                length, or if any sample has no valid action.
        """
        if len(states) != len(valid_actions_list):
            raise ValueError(
                "states and valid_actions_list must have the same length "
                f"(got {len(states)} and {len(valid_actions_list)})"
            )

        selected_actions = []
        all_probs = []
        for state, valid_actions in zip(states, valid_actions_list):
            action, probs = self.predict_action(
                state,
                valid_actions,
                deterministic=deterministic,
            )
            selected_actions.append(action)
            all_probs.append(probs)

        return selected_actions, all_probs
|
||
|
||
|
||
# Usage example
if __name__ == '__main__':
    # Example: load a trained model and run inference with it.
    model_path = 'models/ppo_model_best.pth'

    # Only the model loading is guarded: a FileNotFoundError raised later
    # during inference must not be reported as a missing model file.
    try:
        inferencer = PPOInference(model_path)
    except FileNotFoundError:
        print(f"模型文件不存在: {model_path}")
        print("请先训练模型!")
    else:
        # Build sample data (the dimension comes from the real data,
        # usually 572).
        state = np.random.randn(572).astype(np.float32)
        valid_actions = np.array([0, 1, 2, 3, 4], dtype=np.uint64)

        # Predict an action.
        selected_action, action_probs = inferencer.predict_action(
            state, valid_actions, deterministic=True
        )

        print(f"\n选择的动作: {selected_action}")
        print(f"动作概率分布: {action_probs}")

        # Predict the state value.
        value = inferencer.predict_value(state)
        print(f"状态价值: {value:.4f}")
|
||
|