// Source listing metadata: 145 lines, 4.9 KiB, C#
using Microsoft.ML.OnnxRuntime;
|
||
using Microsoft.ML.OnnxRuntime.Tensors;
|
||
using System.Collections.Generic;
|
||
using System.Linq;
|
||
using UnityEngine;
|
||
|
||
namespace TH1_Logic.AITrain
{
    /// <summary>
    /// ONNX model inference wrapper.
    /// Dimensions follow the training configuration in train.py: STATE_DIM = 802, ACTION_DIM = 8.
    ///
    /// Dependency: Microsoft.ML.OnnxRuntime
    /// Installation: add the Microsoft.ML.OnnxRuntime package to the Unity project via NuGet.
    /// </summary>
    public class ModelInference
    {
        /// <summary>Shared instance.</summary>
        public static ModelInference Instance = new ModelInference();

        private const int STATE_DIM = 802;
        private const int ACTION_DIM = 8;

        private InferenceSession session;
        // Options the current session was created with; released together with the session in UnloadModel.
        private SessionOptions sessionOptions;
        private string inputName;
        private string outputName;

        /// <summary>
        /// Loads an ONNX model, first unloading any previously loaded model.
        /// </summary>
        /// <param name="modelPath">
        /// Path handed to TH1Resource.ResourceLoader.Load: relative to Resources, without the "Resources/" prefix
        /// and without a file extension. NOTE(review): the asset is read as a TextAsset, so the model file
        /// presumably has to be stored with a ".bytes" extension — confirm.
        /// </param>
        /// <param name="useGPU">
        /// Use the CUDA execution provider on device 0 (default false; requires the
        /// Microsoft.ML.OnnxRuntime.Gpu package and CUDA support).
        /// </param>
        /// <exception cref="System.IO.FileNotFoundException">The model asset could not be loaded.</exception>
        public void LoadModel(string modelPath = "AIModel/cql_model", bool useGPU = false)
        {
            UnloadModel();

            // Load the asset before creating any native objects so a missing model leaks nothing.
            TextAsset modelAsset = TH1Resource.ResourceLoader.Load<TextAsset>(modelPath);
            if (modelAsset == null)
            {
                throw new System.IO.FileNotFoundException($"无法从Resources加载模型: {modelPath}", modelPath);
            }

            var options = new SessionOptions();
            InferenceSession newSession = null;
            try
            {
                if (useGPU)
                {
                    options.AppendExecutionProvider_CUDA(0);
                }

                newSession = new InferenceSession(modelAsset.bytes, options);

                // This class feeds one input and reads one output: use the first of each.
                string newInputName = newSession.InputMetadata.Keys.First();
                string newOutputName = newSession.OutputMetadata.Keys.First();

                // Publish state only once everything succeeded.
                session = newSession;
                sessionOptions = options;
                inputName = newInputName;
                outputName = newOutputName;
            }
            catch
            {
                // Release native handles on any failure, then let the caller see the original error.
                newSession?.Dispose();
                options.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Unloads the model and releases its native resources. Safe to call when nothing is loaded.
        /// </summary>
        public void UnloadModel()
        {
            session?.Dispose();
            session = null;

            sessionOptions?.Dispose();
            sessionOptions = null;

            inputName = null;
            outputName = null;
        }

        /// <summary>
        /// Inference: picks the legal action closest to the model's predicted action.
        /// </summary>
        /// <param name="state">State vector; must contain exactly 802 values.</param>
        /// <param name="legalActions">Legal actions; each is expected to hold 8 floats. Other entries are skipped.</param>
        /// <returns>
        /// Index into <paramref name="legalActions"/> of the chosen action, or -1 on failure. Failure covers:
        /// no model loaded, invalid state, empty list, unexpected output shape, or no comparable action.
        /// A single-element list returns 0 without running the model.
        /// </returns>
        public int Predict(float[] state, List<List<float>> legalActions)
        {
            if (session == null || state == null || state.Length != STATE_DIM)
            {
                return -1;
            }

            if (legalActions == null || legalActions.Count == 0)
            {
                return -1;
            }

            // Only one choice: no need to run the model.
            if (legalActions.Count == 1)
            {
                return 0;
            }

            float[] predictedAction = RunModel(state);
            if (predictedAction == null)
            {
                return -1;
            }

            return FindNearestAction(predictedAction, legalActions);
        }

        /// <summary>
        /// Runs the loaded model on one state.
        /// Returns the first ACTION_DIM values of the output's last axis.
        /// Returns null if that axis is shorter than ACTION_DIM.
        /// </summary>
        private float[] RunModel(float[] state)
        {
            // Wrap the caller's array as a (1, STATE_DIM) tensor instead of copying it element by element.
            var inputTensor = new DenseTensor<float>(state, new[] { 1, STATE_DIM });
            var inputs = new List<NamedOnnxValue>
            {
                NamedOnnxValue.CreateFromTensor(inputName, inputTensor)
            };

            // Request only the output that is actually read.
            using (var results = session.Run(inputs, new[] { outputName }))
            {
                Tensor<float> outputTensor = results.First().AsTensor<float>();

                // Expect the action on the last axis, e.g. (1, ACTION_DIM); reject anything shorter
                // instead of failing with an index exception.
                int rank = outputTensor.Rank;
                if (rank == 0 || outputTensor.Dimensions[rank - 1] < ACTION_DIM)
                {
                    return null;
                }

                // Flat index i addresses element [0, ..., 0, i] of the output.
                var action = new float[ACTION_DIM];
                for (int i = 0; i < ACTION_DIM; i++)
                {
                    action[i] = outputTensor.GetValue(i);
                }

                return action;
            }
        }

        /// <summary>
        /// Finds the legal action nearest to <paramref name="predicted"/> by squared Euclidean distance.
        /// Ties go to the lower index.
        /// Entries that are null or not ACTION_DIM long are skipped.
        /// Returns -1 if no entry yields a comparable (non-NaN) distance.
        /// </summary>
        private static int FindNearestAction(float[] predicted, List<List<float>> legalActions)
        {
            int bestIndex = -1;
            float minDistance = float.MaxValue;

            for (int i = 0; i < legalActions.Count; i++)
            {
                List<float> candidate = legalActions[i];
                if (candidate == null || candidate.Count != ACTION_DIM)
                {
                    continue;
                }

                // Squared distance orders candidates the same way as Euclidean distance, without a sqrt.
                float distance = 0f;
                for (int j = 0; j < ACTION_DIM; j++)
                {
                    float diff = predicted[j] - candidate[j];
                    distance += diff * diff;
                }

                // A NaN distance (e.g. NaN in the model output) can never be "closest"; skip it.
                if (float.IsNaN(distance))
                {
                    continue;
                }

                if (bestIndex < 0 || distance < minDistance)
                {
                    minDistance = distance;
                    bestIndex = i;
                }
            }

            return bestIndex;
        }
    }
}
|
||
|