TH1/Unity/Assets/Scripts/TH1_Logic/AITrain/ModelInference.cs
2026-06-10 11:58:18 +08:00

145 lines
4.9 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
namespace TH1_Logic.AITrain
{
/// <summary>
/// ONNX model inference wrapper.
/// Matches the train.py configuration: STATE_DIM = 802, ACTION_DIM = 8.
///
/// Dependency: Microsoft.ML.OnnxRuntime (install the Microsoft.ML.OnnxRuntime package into Unity via NuGet).
/// GPU inference additionally requires the Microsoft.ML.OnnxRuntime.Gpu package and CUDA.
/// </summary>
public class ModelInference
{
    /// <summary>Shared instance used by callers.</summary>
    public static ModelInference Instance = new ModelInference();

    private const int STATE_DIM = 802;
    private const int ACTION_DIM = 8;

    private InferenceSession session;
    private string inputName;
    private string outputName;

    /// <summary>
    /// Loads an ONNX model, first unloading any previously loaded model.
    /// </summary>
    /// <param name="modelPath">
    /// Path of the model asset passed to TH1Resource.ResourceLoader.Load
    /// (per the original comment: relative to Resources, without extension or "Resources/" prefix).
    /// </param>
    /// <param name="useGPU">Use the CUDA execution provider (default false; requires CUDA support).</param>
    /// <exception cref="System.IO.FileNotFoundException">The model asset could not be loaded.</exception>
    public void LoadModel(string modelPath = "AIModel/cql_model", bool useGPU = false)
    {
        UnloadModel();

        // Load the model bytes before allocating any native ONNX Runtime objects,
        // so a missing asset cannot leak them.
        // NOTE: Unity's `== null` is intentional here (UnityEngine.Object overloads equality).
        // NOTE(review): Unity only imports binary TextAssets reliably with the .bytes extension.
        TextAsset modelAsset = TH1Resource.ResourceLoader.Load<TextAsset>(modelPath);
        if (modelAsset == null)
        {
            throw new System.IO.FileNotFoundException($"无法从Resources加载模型: {modelPath}", modelPath);
        }

        // SessionOptions wraps a native handle; it is only needed while the session is being
        // created, so dispose it deterministically (including when CUDA setup or creation throws).
        using (var options = new SessionOptions())
        {
            if (useGPU)
            {
                // Requires the Microsoft.ML.OnnxRuntime.Gpu package.
                options.AppendExecutionProvider_CUDA(0);
            }

            // Create the session directly from the in-memory model bytes.
            session = new InferenceSession(modelAsset.bytes, options);
        }

        // The model is assumed to have a single input and a single output; use the first of each.
        inputName = session.InputMetadata.Keys.First();
        outputName = session.OutputMetadata.Keys.First();
    }

    /// <summary>
    /// Unloads the model and releases the native inference session. Safe to call repeatedly.
    /// </summary>
    public void UnloadModel()
    {
        session?.Dispose();
        session = null;
        inputName = null;
        outputName = null;
    }

    /// <summary>
    /// Runs inference and picks the legal action closest to the model's predicted action.
    /// </summary>
    /// <param name="state">State vector; must have exactly STATE_DIM (802) elements.</param>
    /// <param name="legalActions">Candidate actions; each should have exactly ACTION_DIM (8) elements.
    /// Malformed (null or wrong-length) entries are skipped.</param>
    /// <returns>
    /// Index into <paramref name="legalActions"/> of the chosen action. When there is exactly one
    /// candidate, returns 0 without running the model. Returns -1 if no model is loaded, the state is
    /// invalid, the list is null/empty, the model output is too small, or no candidate is well-formed.
    /// </returns>
    public int Predict(float[] state, List<List<float>> legalActions)
    {
        if (session == null || state == null || state.Length != STATE_DIM)
        {
            return -1;
        }
        if (legalActions == null || legalActions.Count == 0)
        {
            return -1;
        }

        // Only one choice: skip inference entirely.
        if (legalActions.Count == 1)
        {
            return 0;
        }

        float[] predictedAction = RunModel(state);
        if (predictedAction == null)
        {
            return -1;
        }

        return FindNearestLegalAction(predictedAction, legalActions);
    }

    // Runs the model on a single state (batch of one) and returns the first ACTION_DIM output values,
    // or null when the output holds fewer than ACTION_DIM values.
    private float[] RunModel(float[] state)
    {
        // Input tensor of shape (1, STATE_DIM); copied so the caller's array is not aliased.
        var inputTensor = new DenseTensor<float>(new[] { 1, STATE_DIM });
        for (int i = 0; i < STATE_DIM; i++)
        {
            inputTensor[0, i] = state[i];
        }

        var inputs = new List<NamedOnnxValue>
        {
            NamedOnnxValue.CreateFromTensor(inputName, inputTensor)
        };

        // Request only the output we use; results own native buffers and must be disposed.
        using (var results = session.Run(inputs, new[] { outputName }))
        {
            var outputTensor = results.First().AsTensor<float>();

            // Read by flat (row-major) index, so both (1, ACTION_DIM) and (ACTION_DIM) outputs
            // work for a batch of one, and a too-small output is reported instead of throwing.
            if (outputTensor.Length < ACTION_DIM)
            {
                return null;
            }

            var predicted = new float[ACTION_DIM];
            for (int i = 0; i < ACTION_DIM; i++)
            {
                predicted[i] = outputTensor.GetValue(i);
            }
            return predicted;
        }
    }

    // Returns the index of the well-formed legal action with the smallest squared Euclidean
    // distance to `predicted`, or -1 when no entry has exactly ACTION_DIM values.
    private static int FindNearestLegalAction(float[] predicted, List<List<float>> legalActions)
    {
        int bestIndex = -1;
        float minDistance = float.MaxValue;

        for (int i = 0; i < legalActions.Count; i++)
        {
            List<float> action = legalActions[i];
            if (action == null || action.Count != ACTION_DIM)
            {
                continue;
            }

            // Squared distance has the same ordering as Euclidean distance, without a sqrt.
            float distance = 0f;
            for (int j = 0; j < ACTION_DIM; j++)
            {
                float diff = predicted[j] - action[j];
                distance += diff * diff;
            }

            // Always accept the first well-formed action, so NaN or overflowing distances
            // still yield a legal index rather than an unvalidated default.
            if (bestIndex < 0 || distance < minDistance)
            {
                minDistance = distance;
                bestIndex = i;
            }
        }

        return bestIndex;
    }
}
}