TH1/AITrainPython/inference.py
2025-12-06 11:44:43 +08:00

135 lines
4.7 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
模型推理接口
用于在游戏中使用训练好的模型进行决策
"""
import torch
import numpy as np
from ppo_model import PPONetwork
class PPOInference:
    """Inference wrapper around a trained PPO policy network.

    Loads a checkpoint into a ``PPONetwork`` and exposes helpers to pick an
    action among a set of valid action ids and to estimate a state's value.
    """

    def __init__(self, model_path, state_dim=572, hidden_dim=512, device=None):
        """Build the network and load its weights from a checkpoint.

        Args:
            model_path: Path of the checkpoint file. The loaded object must
                provide a ``'policy_state_dict'`` entry.
            state_dim: Dimension of the state vector.
            hidden_dim: Width of the hidden layers.
            device: Inference device (``torch.device``, device string or
                CUDA index). ``None`` selects CUDA when available, else CPU.
        """
        # Compare against None explicitly: a plain truthiness test would
        # silently discard the valid device index ``0``.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        # Load the model.
        self.policy = PPONetwork(state_dim=state_dim, hidden_dim=hidden_dim).to(self.device)
        # NOTE: torch.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.policy.load_state_dict(checkpoint['policy_state_dict'])
        self.policy.eval()
        print(f"模型加载成功: {model_path}")
        print(f"推理设备: {self.device}")

    def _state_to_tensor(self, state):
        """Return ``state`` as a float32 batch of size 1 on the inference device."""
        return torch.as_tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device)

    def predict_action(self, state, valid_actions, deterministic=False):
        """Choose one of ``valid_actions`` for the given state.

        Args:
            state: (state_dim,) numpy array, the current state.
            valid_actions: (n_actions,) numpy array of uint64 action codes.
            deterministic: If True, pick the most probable action; otherwise
                sample an action according to the predicted distribution.

        Returns:
            selected_action: The chosen entry of ``valid_actions``.
            action_probs: (n_actions,) numpy array with the probability of
                every valid action.

        Raises:
            ValueError: If ``valid_actions`` is empty.
        """
        n_actions = len(valid_actions)
        if n_actions == 0:
            # Fail early with a clear message instead of an obscure error
            # from the model or from numpy further down.
            raise ValueError("valid_actions must contain at least one action")

        with torch.no_grad():
            state_tensor = self._state_to_tensor(state)
            # The model scores the concrete action ids, not just their count.
            # NOTE(review): torch.from_numpy only accepts uint64 arrays on
            # recent PyTorch releases - confirm the deployed version.
            action_ids_tensor = torch.from_numpy(valid_actions).to(self.device)
            action_probs = self.policy.get_action_probs(state_tensor, action_ids_tensor)
            # Flatten so that a batch-shaped (1, n_actions) output still
            # yields the documented 1-D array that np.random.choice requires.
            action_probs = action_probs.cpu().numpy().reshape(-1)

        if deterministic:
            # Greedy: take the most probable action.
            action_index = np.argmax(action_probs)
        else:
            # Stochastic: sample according to the distribution.
            action_index = np.random.choice(n_actions, p=action_probs)
        return valid_actions[action_index], action_probs

    def predict_value(self, state):
        """Estimate the value of a state.

        Args:
            state: (state_dim,) numpy array.

        Returns:
            value: The state value as a Python float.
        """
        with torch.no_grad():
            # Invoke the module itself rather than ``.forward`` so that any
            # registered module hooks are honoured.
            _, value = self.policy(self._state_to_tensor(state))
        return value.item()

    def batch_predict_actions(self, states, valid_actions_list, deterministic=False):
        """Predict an action for every sample of a batch.

        The samples are evaluated one after another through
        ``predict_action``; this is a convenience wrapper, not a vectorised
        forward pass.

        Args:
            states: (batch_size, state_dim) numpy array.
            valid_actions_list: List of numpy arrays, the valid actions of
                each sample.
            deterministic: Whether to use the deterministic (greedy) policy.

        Returns:
            selected_actions: List with the chosen action of each sample.
            all_probs: List of numpy arrays, one distribution per sample.

        Raises:
            ValueError: If ``states`` and ``valid_actions_list`` differ in
                length.
        """
        if len(states) != len(valid_actions_list):
            raise ValueError(
                "states and valid_actions_list must have the same length "
                f"(got {len(states)} and {len(valid_actions_list)})"
            )

        selected_actions = []
        all_probs = []
        for state, valid_actions in zip(states, valid_actions_list):
            action, probs = self.predict_action(
                state,
                valid_actions,
                deterministic=deterministic,
            )
            selected_actions.append(action)
            all_probs.append(probs)
        return selected_actions, all_probs
# Usage example
if __name__ == '__main__':
    # Example: load a checkpoint and run a single inference pass.
    model_path = 'models/ppo_model_best.pth'
    try:
        engine = PPOInference(model_path)

        # Dummy inputs; the state width is normally taken from the real
        # data (usually 572).
        candidate_actions = np.array([0, 1, 2, 3, 4], dtype=np.uint64)
        sample_state = np.random.randn(572).astype(np.float32)

        # Pick an action greedily.
        selected_action, action_probs = engine.predict_action(
            sample_state,
            candidate_actions,
            deterministic=True,
        )
        print(f"\n选择的动作: {selected_action}")
        print(f"动作概率分布: {action_probs}")

        # Estimate the value of the state.
        value = engine.predict_value(sample_state)
        print(f"状态价值: {value:.4f}")
    except FileNotFoundError:
        print(f"模型文件不存在: {model_path}")
        print("请先训练模型!")