AI 训练相关

This commit is contained in:
wuwenbo 2025-12-02 16:31:52 +08:00
parent 33e469e710
commit 3cd14c5ea7
19 changed files with 1011 additions and 15 deletions

View File

@ -1,14 +1,12 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
/*
* @Author:
* @Description:
* @Date: 20251202 15:12:27
* @Modify:
*/
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Emit;
using MemoryPack;
using OfficeOpenXml;

View File

@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 7d606d8d52c8d2649b6311446fdf8a40
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

Binary file not shown.

View File

@ -0,0 +1,7 @@
fileFormatVersion: 2
guid: 5ded785d366dc8348b3895f3a6182f36
TextScriptImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -25,6 +25,7 @@ namespace RuntimeData
private Dictionary<uint, GridData> _posToGridDict;
private Dictionary<uint, GridData> _gridIdToGridDict;
private Dictionary<uint, int> _gridIdToIndexDict;
private MapConfig _mapCfg;
@ -34,6 +35,7 @@ namespace RuntimeData
GridList = new List<GridData>();
_posToGridDict = new Dictionary<uint, GridData>();
_gridIdToGridDict = new Dictionary<uint, GridData>();
_gridIdToIndexDict = new Dictionary<uint, int>();
}
public GridMapData(MapConfig mapCfg, MapIdGenerator idGenerator)
@ -42,6 +44,7 @@ namespace RuntimeData
GridList = new List<GridData>();
_posToGridDict = new Dictionary<uint, GridData>();
_gridIdToGridDict = new Dictionary<uint, GridData>();
_gridIdToIndexDict = new Dictionary<uint, int>();
for (int x = 0; x < mapCfg.Width; x++)
{
@ -53,6 +56,10 @@ namespace RuntimeData
_posToGridDict[gridData.Pos.PosId] = gridData;
}
}
for (int i = 0; i < GridList.Count; i++)
{
_gridIdToIndexDict[GridList[i].Id] = i;
}
}
public GridMapData(GridMapData copyData)
@ -61,6 +68,7 @@ namespace RuntimeData
GridList = new List<GridData>();
_posToGridDict = new Dictionary<uint, GridData>();
_gridIdToGridDict = new Dictionary<uint, GridData>();
_gridIdToIndexDict = new Dictionary<uint, int>();
foreach (var grid in copyData.GridList)
{
@ -69,6 +77,10 @@ namespace RuntimeData
_gridIdToGridDict[gridData.Id] = gridData;
_posToGridDict[gridData.Pos.PosId] = gridData;
}
for (int i = 0; i < GridList.Count; i++)
{
_gridIdToIndexDict[GridList[i].Id] = i;
}
}
public void DeepCopy(GridMapData copyData)
@ -76,6 +88,7 @@ namespace RuntimeData
_mapCfg = copyData._mapCfg;
_posToGridDict.Clear();
_gridIdToGridDict.Clear();
_gridIdToIndexDict.Clear();
for (int i = 0; i < copyData.GridList.Count; i++)
{
@ -100,6 +113,10 @@ namespace RuntimeData
_posToGridDict.Remove(GridList[i].Pos.PosId);
GridList.RemoveAt(i);
}
for (int i = 0; i < GridList.Count; i++)
{
_gridIdToIndexDict[GridList[i].Id] = i;
}
}
[MemoryPackOnDeserialized]
@ -107,13 +124,19 @@ namespace RuntimeData
{
_posToGridDict ??= new Dictionary<uint, GridData>();
_gridIdToGridDict ??= new Dictionary<uint, GridData>();
_gridIdToIndexDict??= new Dictionary<uint, int>();
_posToGridDict.Clear();
_gridIdToGridDict.Clear();
_gridIdToIndexDict.Clear();
foreach (var grid in GridList)
{
_gridIdToGridDict[grid.Id] = grid;
_posToGridDict[grid.Pos.PosId] = grid;
}
for (int i = 0; i < GridList.Count; i++)
{
_gridIdToIndexDict[GridList[i].Id] = i;
}
}
public void BindMapConfig(MapConfig cfg)
@ -131,6 +154,11 @@ namespace RuntimeData
foreach (var grid in GridList) grid.OnTurnEnd(map);
}
public int GetGridIndexByGid(uint gid)
{
return _gridIdToIndexDict.GetValueOrDefault(gid, 0);
}
// 通过 gid 获取格子数据
public bool GetGridDataByGid(uint gid, out GridData data)
{

View File

@ -273,6 +273,18 @@ namespace RuntimeData
Main.PlayerLogic.EndThisTurn(map, player);
player.OnTurnEnd(map);
}
public int GetMaxPlayerScore()
{
int maxScore = -1;
foreach (var player in PlayerDataList)
{
if (!player.Alive) continue;
if (player.PlayerScore <= maxScore) continue;
maxScore = player.PlayerScore;
}
return maxScore;
}
}

View File

@ -58,6 +58,7 @@ public class UnitTypeDataAssets : ScriptableObject
//giant+level构成string然后做成dict映射
//private Dictionary<string,UnitTypeInfo> _giantTypeDict = new Dictionary<string, UnitTypeInfo>();
private Dictionary<UnitFullType, UnitTypeInfo> _unitTypeDict = new Dictionary<UnitFullType, UnitTypeInfo>();
private Dictionary<UnitFullType, int> _unitTypeIndexDict = new Dictionary<UnitFullType, int>();
[NonSerialized]
private bool _initialized = false;
@ -66,10 +67,13 @@ public class UnitTypeDataAssets : ScriptableObject
{
if (_initialized)
return;
var index = 0;
foreach (var t in UnitTypeList)
{
var key = new UnitFullType(t.UnitType, t.GiantType, t.UnitLevel);
_unitTypeDict[key] = t;
_unitTypeIndexDict[key] = index;
index++;
}
_initialized = true;
}
@ -88,6 +92,12 @@ public class UnitTypeDataAssets : ScriptableObject
return true;
}
public int GetUnitTypeInfoIndex(UnitData unit)
{
Init();
return _unitTypeIndexDict.GetValueOrDefault(unit.UnitFullType, 0);
}
public bool GetUnitTypeInfo(UnitFullType unitType,out UnitTypeInfo unitTypeInfo)
{
Init();

View File

@ -265,6 +265,45 @@ namespace Logic.AI
}
}
public static List<AIActionBase> GeneratorAllActionIds(MapData map, PlayerData selfPlayer)
{
var data = new AICalculatorData();
data.Map = map;
data.Player = selfPlayer;
var selfUnits = new HashSet<UnitData>();
map.GetUnitDataListByPlayerId(selfPlayer.Id, selfUnits);
var selfCities = new HashSet<CityData>();
map.GetCityDataListByPlayerId(selfPlayer.Id, selfCities);
data.TargetParam.PlayerData = selfPlayer;
GeneratorActionIds(data, CommonActionType.LearnTech);
GeneratorActionIds(data, CommonActionType.StartWonder);
GeneratorActionIds(data, CommonActionType.PlayerAction);
foreach (var city in selfCities)
{
data.TargetParam.CityData = city;
GeneratorActionIds(data, CommonActionType.Gain);
GeneratorActionIds(data, CommonActionType.Build);
GeneratorActionIds(data, CommonActionType.BuildWonder);
GeneratorActionIds(data, CommonActionType.GridMisc);
GeneratorActionIds(data, CommonActionType.TrainUnit);
GeneratorActionIds(data, CommonActionType.CityLevelUpAction);
}
foreach (var unit in selfUnits)
{
data.TargetParam.UnitData = unit;
GeneratorActionIds(data, CommonActionType.UnitAction);
GeneratorActionIds(data, CommonActionType.UnitSkill);
GeneratorActionIds(data, CommonActionType.UnitMove);
GeneratorActionIds(data, CommonActionType.UnitAttack);
GeneratorActionIds(data, CommonActionType.AIParamControl);
GeneratorActionIds(data, CommonActionType.UnitAttackAlly);
}
return data.AIActions;
}
public static void GeneratorActionIds(AICalculatorData data, CommonActionType type)
{
var actions = ActionLogicFactory.GetActionLogicByType(type);
@ -397,6 +436,43 @@ namespace Logic.AI
data.AIActions.Add(new AIActionBase(data.TargetParam.GetCopyParam(), action));
}
}
if (type == CommonActionType.AIParamControl)
{
if (data.TargetParam.PlayerData == null) return;
if (data.TargetParam.UnitData == null) return;
if (!data.TargetParam.UnitData.IsAlive()) return;
foreach (var action in actions)
{
if (action.ActionId.AIParamType == AIParamControlType.AIMoney) continue;
if (!action.CheckCan(data.TargetParam)) continue;
data.AIActions.Add(new AIActionBase(data.TargetParam.GetCopyParam(), action));
}
}
if (type == CommonActionType.UnitAttackAlly)
{
if (data.TargetParam.UnitData == null) return;
if (!data.TargetParam.UnitData.Alive) return;
if (data.TargetParam.UnitData.AP <= 0) return;
data.TargetParam.MainObjectType = ActionLogicFactory.GetMainObjectType(type);
foreach (var unit in data.TargetParam.MapData.UnitMap.UnitList)
{
if (unit.Id == data.TargetParam.UnitData.Id) continue;
if (!data.TargetParam.MapData.IsLeagueUnitByUnit(unit.Id, data.TargetParam.UnitData.Id)) continue;
data.TargetParam.TargetUnitData = unit;
data.TargetParam.OnParamChanged();
foreach (var action in actions)
{
if (!action.CheckCan(data.TargetParam)) continue;
data.AIActions.Add(new AIActionBase(data.TargetParam.GetCopyParam(), action));
}
}
}
}
}
}

View File

@ -13,6 +13,7 @@ using NodeCanvas.Framework;
using UnityEngine;
using RuntimeData;
using TH1_Core.Managers;
using TH1_Logic.AITrain;
namespace Logic.AI
@ -114,9 +115,12 @@ namespace Logic.AI
public void Update()
{
if (AILogicState == AILogicState.Finished || AILogicState == AILogicState.Prepare) return;
#if ENABLE_SPEEDUP
if (AILogicState == AILogicState.Pausing)
#else
if (AILogicState == AILogicState.Pausing && !PresentationManager.Busy)
//if (AILogicState == AILogicState.Pausing)
#endif
{
_targetTime -= Time.deltaTime;
if (_targetTime <= 0) AILogicState = AILogicState.Playing;
@ -148,9 +152,31 @@ namespace Logic.AI
}
if (_data.MaxAiAction == null || index > 100) AILogicState = AILogicState.Finished;
else
else
{
#if ENABLE_TRAIN
TrainingState.Instance.GetActionBitCodec(_data.MaxAiAction.ActionLogic.ActionId, _data.MaxAiAction.Param, out var packed);
var curPlayer = _data.MaxAiAction.Param.MapData.CurPlayer;
var beforeScore = TrainingState.Instance.GetMapScore(_data.MaxAiAction.Param.MapData, curPlayer);
var state = TrainingState.Instance.GetMapState(_data.MaxAiAction.Param.MapData, curPlayer);
var validActions = TrainingState.Instance.GetAllActionBitCodec(_data.MaxAiAction.Param.MapData, curPlayer);
#endif
_data.MaxAiAction.ActionLogic.CompleteExecute(_data.MaxAiAction.Param);
#if ENABLE_TRAIN
if (packed != 0)
{
var afterScore = TrainingState.Instance.GetMapScore(_data.MaxAiAction.Param.MapData, curPlayer);
var reward = afterScore - beforeScore;
var down = _data.MaxAiAction.Param.MapData.CheckIfGameEnd(out _);
if (down && curPlayer.Alive) reward += 1000f;
if (!curPlayer.Alive) reward -= 1000f;
TrainingDataRecorder.Instance.RecordStep(state, validActions, packed, reward, down);
}
#endif
_data.MaxAiAction.CheckIsActionDuration();
_targetTime = Mathf.Max(_data.MaxAiAction.Duration, 0f);

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 8aadc7c4125b49b7af9d56af927fd52d
timeCreated: 1764056674

View File

@ -0,0 +1,92 @@
/*
* @Author:
* @Description:
* @Date: 20251201 16:12:06
* @Modify:
*/
using System.Collections.Generic;
using System.IO;
using UnityEngine;
using System.Text;
namespace TH1_Logic.AITrain
{
[System.Serializable]
public class TrainingData
{
public float[] State;
public ulong[] Actions;
public ulong SelectedAction;
public float Reward;
public bool Done;
}
public class TrainingDataRecorder
{
public static TrainingDataRecorder Instance = new TrainingDataRecorder();
private List<TrainingData> _episodeData = new List<TrainingData>();
private string _outputDir;
private TrainingState _trainState;
public TrainingDataRecorder(string outputDir = "TrainingData")
{
_trainState = new TrainingState();
_outputDir = Path.Combine(Application.persistentDataPath, outputDir);
Directory.CreateDirectory(_outputDir);
}
// 记录单步数据到内存
public void RecordStep(float[] state, ulong[] validActions, ulong selectedAction, float reward, bool done)
{
var data = new TrainingData
{
State = state,
Actions = validActions,
SelectedAction = selectedAction,
Reward = reward,
Done = done
};
_episodeData.Add(data);
}
// 游戏结束时一次性写入文件
public void SaveEpisode(string fileName = null)
{
if (_episodeData.Count == 0) return;
if (string.IsNullOrEmpty(fileName))
{
// 使用时间戳 + GUID 的前 8 位确保唯一性
string timestamp = System.DateTime.Now.ToString("yyyyMMdd_HHmmss");
string uniqueId = System.Guid.NewGuid().ToString("N").Substring(0, 8);
fileName = $"episode_{timestamp}_{uniqueId}.jsonl";
}
string filePath = Path.Combine(_outputDir, fileName);
StringBuilder sb = new StringBuilder();
foreach (var data in _episodeData)
{
string json = JsonUtility.ToJson(data);
sb.AppendLine(json);
}
File.WriteAllText(filePath, sb.ToString(), Encoding.UTF8);
Debug.Log($"训练数据已保存: {filePath}, 共 {_episodeData.Count} 步");
_episodeData.Clear();
}
// 清空当前回合数据(不保存)
public void ClearEpisode()
{
_episodeData.Clear();
}
}
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 7caddcfed4744645ae284218e3b4d1f5
timeCreated: 1764579480

View File

@ -0,0 +1,578 @@
/*
* @Author:
* @Description:
* @Date: 20251125 15:11:13
* @Modify:
*/
using System;
using System.Collections.Generic;
using Logic.Action;
using Logic.AI;
using MemoryPack;
using RuntimeData;
using TH1_Logic.Core;
using UnityEngine;
namespace TH1_Logic.AITrain
{
public class TrainingState
{
public static TrainingState Instance = new TrainingState();
private ActionLogicIdData _actionLogicIdData;
// 初始化环境
public void Initialize()
{
Main.Instance.StartGame();
}
// State 获取
// Map State 维度只能向后拓展
// State 的数据选取非常重要,与 Action 的数据一起决定了所有 AI 的行为,对应导向的行为必须和 State 数据相关联
public float[] GetMapState(MapData map, PlayerData selfPlayer)
{
List<float> state = new();
// 1. 己方信息
var selfUnits = new HashSet<UnitData>();
map.GetUnitDataListByPlayerId(selfPlayer.Id, selfUnits);
var selfCities = new HashSet<CityData>();
map.GetCityDataListByPlayerId(selfPlayer.Id, selfCities);
var maxScore = map.PlayerMap.GetMaxPlayerScore();
// 回合数进度
state.Add(Mathf.Min(selfPlayer.Turn / 100f, 1));
// 我的 Player Index
state.Add(GetPlayerIndex(map, selfPlayer) / 10f);
var units = new List<UnitData>();
var cities = new List<CityData>();
foreach (var id in selfPlayer.Sight.SightGidSet)
{
if (map.GetUnitDataByGid(id, out var unit)) units.Add(unit);
if (map.GetCityDataByGid(id, out var city)) cities.Add(city);
}
// Player state容量 10 * 4共 40 维度
for (int i = 0; i < 10; i++)
{
if (i >= map.PlayerMap.PlayerDataList.Count)
{
state.Add(0);
state.Add(0);
state.Add(0);
state.Add(0);
}
else
{
var player = map.PlayerMap.PlayerDataList[i];
var territory = map.GetPlayerTerritoryGridIdSet(player.Id);
// 领土总数比例
state.Add((float)territory.Count / map.MapConfig.Height / map.MapConfig.Width);
// 视野总数比例
state.Add((float)player.Sight.SightGidSet.Count / map.MapConfig.Height / map.MapConfig.Width);
// 钱
state.Add(Mathf.Min(player.PlayerWealth / 50f, 1));
// 积分
state.Add(maxScore > 0 ? player.PlayerScore / (float)maxScore : 0f);
}
}
// 小兵 state 容量 50 * 5 共 250 维度
for (int i = 0; i < 50; i++)
{
if (i >= units.Count)
{
state.Add(0);
state.Add(0);
state.Add(0);
state.Add(0);
state.Add(0);
}
else
{
var unit = units[i];
var grid = unit.Grid(map);
state.Add(GetUnitIndex(map, unit) / 100f);
state.Add(selfUnits.Contains(unit) ? 1 : 0);
state.Add(GetGridIndex(map, grid) / 500f);
state.Add(Table.Instance.UnitTypeDataAssets.GetUnitTypeInfoIndex(unit) / 100f);
state.Add(unit.GetHealthRatio());
}
}
// 城市 state 容量 20 * 4 共 80 维度
for (int i = 0; i < 20; i++)
{
if (i >= cities.Count)
{
state.Add(0);
state.Add(0);
state.Add(0);
state.Add(0);
}
else
{
var city = cities[i];
var grid = city.Grid(map);
state.Add(GetCityIndex(map, city) / 100f);
state.Add(selfCities.Contains(city) ? 1 : 0);
state.Add(GetGridIndex(map, grid) / 500f);
state.Add(Mathf.Min((city.Level + city.LevelExp / (float)city.Level / 2f) / 10f, 1f));
}
}
// 格子 state 容量 100 * 2 共 200 维度 (放最后方便拓展)
var gridStateCount = 0;
foreach (var id in selfPlayer.Sight.SightGidSet)
{
if (gridStateCount >= 100) break;
if (map.GridMap.GetGridDataByGid(id, out var grid)) continue;
if (grid.Resource != ResourceType.None)
{
state.Add(GetGridIndex(map, grid) / 500f);
state.Add((int)grid.Resource / 100f);
gridStateCount++;
if (gridStateCount >= 100) break;
}
if (grid.SpTypeList.Count != 0)
{
foreach (var spType in grid.SpTypeList)
{
state.Add(GetGridIndex(map, grid) / 500f);
state.Add((int)spType / 30f);
gridStateCount++;
if (gridStateCount >= 100) break;
}
}
if (gridStateCount >= 100) break;
}
for (int i = gridStateCount; i < 100; i++)
{
state.Add(0);
state.Add(0);
}
return state.ToArray();
}
// Action 获取
// Action 维度为固定 64 位整数 (可变长编码)
public bool GetActionBitCodec(CommonActionId actionId, CommonActionParams param, out ulong packed)
{
LoadActionLogicIdData();
var bitCodec64Var = new AIActionBitCodec64Var();
bitCodec64Var.ActionId = _actionLogicIdData.GetActionIdIndex(actionId);
if (param.PlayerData != null) bitCodec64Var.PlayerIndex = GetPlayerIndex(param.MapData, param.PlayerData);
if (param.UnitData != null) bitCodec64Var.UnitIndex = GetUnitIndex(param.MapData, param.UnitData);
if (param.CityData != null) bitCodec64Var.CityIndex = GetCityIndex(param.MapData, param.CityData);
if (param.GridData != null) bitCodec64Var.GridIndex = GetGridIndex(param.MapData, param.GridData);
if (param.TargetUnitData != null) bitCodec64Var.TargetUnitIndex = GetUnitIndex(param.MapData, param.TargetUnitData);
if (param.TargetGridData != null) bitCodec64Var.TargetGridIndex = GetGridIndex(param.MapData, param.TargetGridData);
if (param.TargetPlayerData != null) bitCodec64Var.TargetPlayerIndex = GetPlayerIndex(param.MapData, param.TargetPlayerData);
return bitCodec64Var.TryPack(out packed);
}
// Score 获取, Score 差值即为 Reward
public float GetMapScore(MapData mapData, PlayerData player)
{
return GetPlayerScore(mapData, player) + GetUnitsScore(mapData, player) + GetCityScore(mapData, player);
}
// 获取当前所有可以被执行的 Action 的 BitCodec 列表
public ulong[] GetAllActionBitCodec(MapData mapData, PlayerData selfPlayer)
{
var packedList = new List<ulong>();
var actionList = AIActionGenerator.GeneratorAllActionIds(mapData, selfPlayer);
foreach (var action in actionList)
{
if (!GetActionBitCodec(action.ActionLogic.ActionId, action.Param, out var packed)) continue;
packedList.Add(packed);
}
packedList.Add(0);
return packedList.ToArray();
}
// Action 反编码
public bool GetActionFromBitCodec(ulong packed, MapData mapData, out CommonActionId actionId,
out CommonActionParams param)
{
actionId = default;
param = null;
LoadActionLogicIdData();
var codec = new AIActionBitCodec64Var();
if (!codec.TryUnpack(packed)) return false;
// 将 ActionId 索引映射回 CommonActionId
if (codec.ActionId < 0 || codec.ActionId >= _actionLogicIdData.ActionIdList.Count) return false;
actionId = _actionLogicIdData.ActionIdList[codec.ActionId];
var p = new CommonActionParams
{
MapData = mapData
};
// 按索引回填对象(存在即取,不存在即返回 false
if (codec.PlayerIndex != -1)
{
if (codec.PlayerIndex < 0 || codec.PlayerIndex >= mapData.PlayerMap.PlayerDataList.Count) return false;
p.PlayerData = mapData.PlayerMap.PlayerDataList[codec.PlayerIndex];
}
if (codec.UnitIndex != -1)
{
if (codec.UnitIndex < 0 || codec.UnitIndex >= mapData.UnitMap.UnitList.Count) return false;
p.UnitData = mapData.UnitMap.UnitList[codec.UnitIndex];
}
if (codec.CityIndex != -1)
{
if (codec.CityIndex < 0 || codec.CityIndex >= mapData.CityMap.CityList.Count) return false;
p.CityData = mapData.CityMap.CityList[codec.CityIndex];
}
if (codec.GridIndex != -1)
{
if (codec.GridIndex < 0 || codec.GridIndex >= mapData.GridMap.GridList.Count) return false;
p.GridData = mapData.GridMap.GridList[codec.GridIndex];
}
if (codec.TargetUnitIndex != -1)
{
if (codec.TargetUnitIndex < 0 || codec.TargetUnitIndex >= mapData.UnitMap.UnitList.Count) return false;
p.TargetUnitData = mapData.UnitMap.UnitList[codec.TargetUnitIndex];
}
if (codec.TargetGridIndex != -1)
{
if (codec.GridIndex < 0 || codec.TargetGridIndex >= mapData.GridMap.GridList.Count) return false;
p.TargetGridData = mapData.GridMap.GridList[codec.TargetGridIndex];
}
if (codec.TargetPlayerIndex != -1)
{
if (codec.TargetPlayerIndex < 0 || codec.TargetPlayerIndex >= mapData.PlayerMap.PlayerDataList.Count)
return false;
p.TargetPlayerData = mapData.PlayerMap.PlayerDataList[codec.TargetPlayerIndex];
}
param = p;
param.MainObjectType = ActionLogicFactory.GetMainObjectType(actionId.ActionType);
param.OnParamChanged();
return true;
}
// 获取 Player 的 Index
private int GetPlayerIndex(MapData map, PlayerData player)
{
int index = 0;
foreach (var p in map.PlayerMap.PlayerDataList)
{
if (p.Id == player.Id)
{
return index;
}
index++;
}
return index;
}
// 获取 Unit 的 Index
private int GetUnitIndex(MapData map, UnitData unit)
{
int index = 0;
foreach (var u in map.UnitMap.UnitList)
{
if (u.Id == unit.Id)
{
return index;
}
index++;
}
return index;
}
// 获取 City 的 Index
private int GetCityIndex(MapData map, CityData city)
{
int index = 0;
foreach (var c in map.CityMap.CityList)
{
if (c.Id == city.Id)
{
return index;
}
index++;
}
return index;
}
// 获取 Grid 的 Index
private int GetGridIndex(MapData map, GridData grid)
{
return map.GridMap.GetGridIndexByGid(grid.Id);
}
private float GetPlayerScore(MapData mapData, PlayerData playerData)
{
float score = 0f;
score += playerData.PlayerScore / 10f;
return 1000f - mapData.PlayerMap.PlayerDataList.Count * 100 + playerData.PlayerScore / 10f;
}
private float GetUnitsScore(MapData mapData, PlayerData playerData)
{
float score = 0f;
var units = new HashSet<UnitData>();
mapData.GetUnitDataListByPlayerId(playerData.Id, units);
foreach (var unit in mapData.UnitMap.UnitList)
{
var unitScore = unit.Health + unit.GetAttackRange() + unit.GetMoveRange() +
unit.GetAllAttackValue(mapData) + unit.GetAllDefenseValue(mapData);
if (units.Contains(unit)) score += unitScore;
else score -= unitScore;
}
return score / 10f;
}
private float GetCityScore(MapData mapData, PlayerData playerData)
{
float score = 0f;
var cities = new HashSet<CityData>();
mapData.GetCityDataListByPlayerId(playerData.Id, cities);
foreach (var city in mapData.CityMap.CityList)
{
if (cities.Contains(city)) score += city.Level * 10 + city.LevelExp / (float)city.Level * 4;
else score -= city.Level * 10 + city.LevelExp / (float)city.Level * 4;
}
return score / 10f;
}
private void LoadActionLogicIdData()
{
if (_actionLogicIdData != null) return;
TextAsset asset = Resources.Load<TextAsset>($"CommonIdData/CommonIdData");
var data = asset?.bytes ?? Array.Empty<byte>();
_actionLogicIdData = MemoryPackSerializer.Deserialize<ActionLogicIdData>(data) ?? new ActionLogicIdData();
}
}
[MemoryPackable]
public partial class ActionLogicIdData
{
public List<CommonActionId> ActionIdList = new();
private Dictionary<CommonActionId, int> _idIndex = new();
public void Initialize()
{
if (_idIndex.Count == ActionIdList.Count) return;
_idIndex.Clear();
for (int i = 0; i < ActionIdList.Count; i++)
{
_idIndex[ActionIdList[i]] = i;
}
}
public int GetActionIdIndex(CommonActionId actionId)
{
Initialize();
return _idIndex.GetValueOrDefault(actionId, -1);
}
public void AddActionId(CommonActionId actionId)
{
Initialize();
if (_idIndex.ContainsKey(actionId)) return;
ActionIdList.Add(actionId);
_idIndex[actionId] = ActionIdList.Count - 1;
}
}
public class AIActionBitCodec64Var
{
public int ActionId;
public int PlayerIndex;
public int UnitIndex;
public int CityIndex;
public int GridIndex;
public int TargetUnitIndex;
public int TargetGridIndex;
public int TargetPlayerIndex;
// 位宽
private const int ACTIONID_BITS = 9; // 0..500 bit 0-8
private const int PLAYER_BITS = 4; // 0..10 bit 9-15
private const int UNIT_BITS = 6; // 0..50
private const int CITY_BITS = 5; // 0..20
private const int GRID_BITS = 9; // 0..500
private const int TPLAYER_BITS = 4;
private const int TUNIT_BITS = 6;
private const int TGRID_BITS = 9;
// 存在位数量与起始(低位)
private const int PRESENCE_BITS = 7; // P,U,C,G,TU,TG,TP
private const int BIT_HAS_PLAYER = 9;
private const int BIT_HAS_UNIT = 10;
private const int BIT_HAS_CITY = 11;
private const int BIT_HAS_GRID = 12;
private const int BIT_HAS_TUNIT = 13;
private const int BIT_HAS_TGRID = 14;
private const int BIT_HAS_TPLAYER = 15;
private static bool InRange(int v, int maxInclusive) => v >= 0 && v <= maxInclusive;
public AIActionBitCodec64Var()
{
PlayerIndex = -1;
UnitIndex = -1;
CityIndex = -1;
GridIndex = -1;
TargetUnitIndex = -1;
TargetGridIndex = -1;
TargetPlayerIndex = -1;
}
// 紧凑序列化(前置存在位 + 变长位段)
public bool TryPack(out ulong packed)
{
packed = 0UL;
bool hasPlayer = PlayerIndex != -1;
bool hasUnit = UnitIndex != -1;
bool hasCity = CityIndex != -1;
bool hasGrid = GridIndex != -1;
bool hasTUnit = TargetUnitIndex != -1;
bool hasTGrid = TargetGridIndex != -1;
bool hasTPlayer = TargetPlayerIndex != -1;
// 校验:必须非负且不超过各自上限
if (!InRange(ActionId, 500)) return false;
if (hasPlayer && !InRange(PlayerIndex, 10)) return false;
if (hasUnit && !InRange(UnitIndex, 50)) return false;
if (hasCity && !InRange(CityIndex, 20)) return false;
if (hasGrid && !InRange(GridIndex, 500)) return false;
if (hasTPlayer && !InRange(TargetPlayerIndex, 10)) return false;
if (hasTUnit && !InRange(TargetUnitIndex, 50)) return false;
if (hasTGrid && !InRange(TargetGridIndex, 500)) return false;
// 1) ActionId [0..8]
ulong v = (ulong)ActionId & ((1UL << ACTIONID_BITS) - 1UL);
// 2) 存在位 [9..15]
if (hasPlayer) v |= 1UL << BIT_HAS_PLAYER;
if (hasUnit) v |= 1UL << BIT_HAS_UNIT;
if (hasCity) v |= 1UL << BIT_HAS_CITY;
if (hasGrid) v |= 1UL << BIT_HAS_GRID;
if (hasTUnit) v |= 1UL << BIT_HAS_TUNIT;
if (hasTGrid) v |= 1UL << BIT_HAS_TGRID;
if (hasTPlayer) v |= 1UL << BIT_HAS_TPLAYER;
// 3) 变长段从 bit 16 开始
int shift = ACTIONID_BITS + PRESENCE_BITS; // 16
if (hasPlayer)
{
v |= ((ulong)PlayerIndex & ((1UL << PLAYER_BITS) - 1UL)) << shift;
shift += PLAYER_BITS;
}
if (hasUnit)
{
v |= ((ulong)UnitIndex & ((1UL << UNIT_BITS) - 1UL)) << shift;
shift += UNIT_BITS;
}
if (hasCity)
{
v |= ((ulong)CityIndex & ((1UL << CITY_BITS) - 1UL)) << shift;
shift += CITY_BITS;
}
if (hasGrid)
{
v |= ((ulong)GridIndex & ((1UL << GRID_BITS) - 1UL)) << shift;
shift += GRID_BITS;
}
if (hasTUnit)
{
v |= ((ulong)TargetUnitIndex & ((1UL << TUNIT_BITS) - 1UL)) << shift;
shift += TUNIT_BITS;
}
if (hasTGrid)
{
v |= ((ulong)TargetGridIndex & ((1UL << TGRID_BITS) - 1UL)) << shift;
shift += TGRID_BITS;
}
if (hasTPlayer)
{
v |= ((ulong)TargetPlayerIndex & ((1UL << TPLAYER_BITS) - 1UL)) << shift;
shift += TPLAYER_BITS;
}
packed = v;
return true;
}
// 反序列化(读取存在位,再按位宽依序提取)
// 对应解包,顺序必须与 TryPack 完全一致
public bool TryUnpack(ulong packed)
{
ActionId = (int)(packed & ((1UL << ACTIONID_BITS) - 1UL));
if (!InRange(ActionId, 500)) return false;
bool hasPlayer = (packed & (1UL << BIT_HAS_PLAYER)) != 0;
bool hasUnit = (packed & (1UL << BIT_HAS_UNIT)) != 0;
bool hasCity = (packed & (1UL << BIT_HAS_CITY)) != 0;
bool hasGrid = (packed & (1UL << BIT_HAS_GRID)) != 0;
bool hasTUnit = (packed & (1UL << BIT_HAS_TUNIT)) != 0;
bool hasTGrid = (packed & (1UL << BIT_HAS_TGRID)) != 0;
bool hasTPlayer = (packed & (1UL << BIT_HAS_TPLAYER)) != 0;
int shift = ACTIONID_BITS + PRESENCE_BITS; // 16
int Read(int bits)
{
int val = (int)((packed >> shift) & ((1UL << bits) - 1UL));
shift += bits;
return val;
}
PlayerIndex = hasPlayer ? Read(PLAYER_BITS) : -1;
if (hasPlayer && !InRange(PlayerIndex, 10)) return false;
UnitIndex = hasUnit ? Read(UNIT_BITS) : -1;
if (hasUnit && !InRange(UnitIndex, 50)) return false;
CityIndex = hasCity ? Read(CITY_BITS) : -1;
if (hasCity && !InRange(CityIndex, 20)) return false;
GridIndex = hasGrid ? Read(GRID_BITS) : -1;
if (hasGrid && !InRange(GridIndex, 500)) return false;
TargetUnitIndex = hasTUnit ? Read(TUNIT_BITS) : -1;
if (hasTUnit && !InRange(TargetUnitIndex, 50)) return false;
TargetGridIndex = hasTGrid ? Read(TGRID_BITS) : -1;
if (hasTGrid && !InRange(TargetGridIndex, 500)) return false;
TargetPlayerIndex = hasTPlayer ? Read(TPLAYER_BITS) : -1;
if (hasTPlayer && !InRange(TargetPlayerIndex, 10)) return false;
return true;
}
}
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 199f5199fd66439f87cec581d0b58eba
timeCreated: 1764056709

View File

@ -889,6 +889,8 @@ namespace Logic.Action
if (actionType == CommonActionType.UnitMove) return MainObjectType.Unit;
if (actionType == CommonActionType.UnitAttack) return MainObjectType.Unit;
if (actionType == CommonActionType.PlayerAction) return MainObjectType.Player;
if (actionType == CommonActionType.AIParamControl) return MainObjectType.Player;
if (actionType == CommonActionType.UnitAttackAlly) return MainObjectType.Unit;
return MainObjectType.Player;
}
}
@ -984,6 +986,7 @@ namespace Logic.Action
{
return Execute(actionParams);
}
// 网络调用
public virtual bool NetCompleteExecute(CommonActionParams actionParams)
{
@ -1600,6 +1603,7 @@ namespace Logic.Action
public override bool CheckCan(CommonActionParams actionParams)
{
if (!actionParams.UnitData.IsAlive() || !actionParams.TargetUnitData.IsAlive()) return false;
if (actionParams.UnitData.Id == actionParams.TargetUnitData.Id) return false;
if (!actionParams.UnitData.CanAttackAll(actionParams.MapData) &&
actionParams.MapData.IsLeagueUnitByUnit(actionParams.UnitData.Id, actionParams.TargetUnitData.Id)) return false;
@ -1771,11 +1775,17 @@ namespace Logic.Action
public override bool CheckCan(CommonActionParams actionParams)
{
if (actionParams.UnitData.IsLimitSelfAttack(actionParams.MapData)) return false;
if (actionParams.UnitData.Id == actionParams.TargetUnitData.Id) return false;
if (!actionParams.UnitData.IsAlive() || !actionParams.TargetUnitData.IsAlive()) return false;
if (!actionParams.MapData.IsLeagueUnitByUnit(actionParams.UnitData.Id, actionParams.TargetUnitData.Id)) return false;
if (!actionParams.MapData.GetPlayerDataByUnitId(actionParams.UnitData.Id, out _)) return false;
if (!actionParams.MapData.GetPlayerDataByUnitId(actionParams.TargetUnitData.Id, out _)) return false;
if (!actionParams.MapData.GetCityDataByUnitId(actionParams.UnitData.Id, out _)) return false;
if (!actionParams.MapData.GetCityDataByUnitId(actionParams.TargetUnitData.Id, out _)) return false;
if (!actionParams.MapData.GetGridDataByUnitId(actionParams.UnitData.Id, out var unitGrid)) return false;
if (!actionParams.MapData.GetGridDataByUnitId(actionParams.TargetUnitData.Id, out var targetUnitGrid)) return false;
if (actionParams.MapData.GridMap.CalcDistance(unitGrid, targetUnitGrid) >
actionParams.UnitData.GetAttackRange()) return false;
return true;
}

View File

@ -11,6 +11,7 @@ using Logic.AI;
using Logic.Audio;
using RuntimeData;
using TH1_Core.Managers;
using TH1_Logic.AITrain;
using TH1_Logic.Core;
using TH1_Logic.Net;
using TH1_Logic.Steam;
@ -120,8 +121,24 @@ namespace Logic
// AI 行为执行完毕
if (_aiLogic.AILogicState == AILogicState.Finished)
{
#if ENABLE_TRAIN
var curPlayer = Main.MapData.CurPlayer;
var beforeScore = TrainingState.Instance.GetMapScore(Main.MapData, curPlayer);
var state = TrainingState.Instance.GetMapState(Main.MapData, curPlayer);
var validActions = TrainingState.Instance.GetAllActionBitCodec(Main.MapData, curPlayer);
#endif
Main.PlayerLogic.EndPlayerTurn(Main.MapData, _aiLogic.PlayerData.Id);
_aiLogic.FinishAILogic();
#if ENABLE_TRAIN
var afterScore = TrainingState.Instance.GetMapScore(Main.MapData, curPlayer);
var reward = afterScore - beforeScore;
var down = Main.MapData.CheckIfGameEnd(out _);
if (down && curPlayer.Alive) reward += 1000f;
if (!curPlayer.Alive) reward -= 1000f;
TrainingDataRecorder.Instance.RecordStep(state, validActions, 0, reward, down);
#endif
}
}
}
@ -403,6 +420,7 @@ namespace Logic
var record = Main.MapData.ExportGameRecord();
GameRecordManager.Instance.AddRecord(record);
TrainingDataRecorder.Instance.SaveEpisode();
}
public override void End()

View File

@ -0,0 +1,99 @@
/*
* @Author:
* @Description:
* @Date: 20250725 14:07:55
* @Modify:
*/
using System;
using System.Linq;
using Logic.Action;
using Logic.Config;
using TH1_Logic.AITrain;
using Unity.VisualScripting;
using UnityEditor;
using UnityEngine;
namespace Logic.Editor
{
public class AITrainEditorWindow : EditorWindow
{
// 背景
private GUIStyle _redBoxStyle;
private GUIStyle _whiteBoxStyle;
[MenuItem("Tools/AI训练窗口")]
private static void ShowWindow()
{
var window = GetWindow<AITrainEditorWindow>();
window.titleContent = new GUIContent("AI训练窗口");
window.Show();
}
private void OnEnable()
{
}
private void OnGUI()
{
if (_redBoxStyle == null)
{
_redBoxStyle = InspectorUtils.GetHelpBoxStyle();
InspectorUtils.AddBorder(_redBoxStyle, new Color(0.5f, 0.4f, 0.4f, 0.6f));
}
if (_whiteBoxStyle == null)
{
_whiteBoxStyle = InspectorUtils.GetHelpBoxStyle();
InspectorUtils.AddBorder(_whiteBoxStyle, new Color(1f, 1f, 1f, 0.2f));
}
if (InspectorUtils.InspectorButtonWithTextWidth($"构建 AI 行为集"))
{
ActionLogicIdData data = null;
var bytes = LoadActionLogicIdData();
if (bytes.Length > 0) data = MemoryPack.MemoryPackSerializer.Deserialize<ActionLogicIdData>(bytes);
if (data == null) data = new ActionLogicIdData();
var dict = ActionLogicFactory.GetActionLogicDict();
foreach (var actionId in dict.Keys) data.AddActionId(actionId);
SaveActionLogicIdData(MemoryPack.MemoryPackSerializer.Serialize(data));
Debug.LogError($"Action数量 {data.ActionIdList.Count}");
}
}
private byte[] LoadActionLogicIdData()
{
TextAsset asset = Resources.Load<TextAsset>($"CommonIdData/CommonIdData");
return asset?.bytes ?? Array.Empty<byte>();
}
private void SaveActionLogicIdData(byte[] data)
{
// 构建完整路径Resources 子目录)
string directory = "Assets/Resources/CommonIdData";
string filePath = $"{directory}/CommonIdData.bytes";
// 检查目录,不存在则创建
if (!System.IO.Directory.Exists(directory))
{
System.IO.Directory.CreateDirectory(directory);
}
// 写入文件(存在则覆盖,不存在则创建)
try
{
System.IO.File.WriteAllBytes(filePath, data);
// 刷新 Unity 资源数据库
AssetDatabase.Refresh();
}
catch (Exception e)
{
Debug.LogError($"写入文件失败: {e.Message}");
}
}
}
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 5b0319a8d0354ad2a6899d5f6ddaf928
timeCreated: 1764405047

View File

@ -17,6 +17,8 @@ namespace Logic.Editor
public class BuildEditor : EditorWindow
{
// 定义宏名称
public const string ENABLE_SPEEDUP = "ENABLE_SPEEDUP";
public const string ENABLE_TRAIN = "ENABLE_TRAIN";
public const string GAME_AUTO_DEBUG = "GAME_AUTO_DEBUG";
public const string STEAM_CHANNEL = "STEAM_CHANNEL";
public const string USE_INPUT = "USE_INPUT";
@ -112,10 +114,30 @@ namespace Logic.Editor
selectedVersion.Description = EditorGUILayout.TextArea(selectedVersion.Description, GUILayout.Height(60));
EditorGUILayout.EndVertical();
EditorGUILayout.BeginHorizontal();
#if ENABLE_TRAIN
InspectorUtils.InspectorTextWidthRich($"<b><color=red>训练模式已开启:</color></b>");
if (InspectorUtils.InspectorButtonWithTextWidth($"关闭训练模式")) RemoveDefine(ENABLE_TRAIN);
#else
InspectorUtils.InspectorTextWidthRich($"<b>训练模式已关闭:</b>");
if (InspectorUtils.InspectorButtonWithTextWidth($"开启训练模式")) AddDefine(ENABLE_TRAIN);
#endif
EditorGUILayout.EndHorizontal();
EditorGUILayout.BeginHorizontal();
#if ENABLE_TRAIN
InspectorUtils.InspectorTextWidthRich($"<b><color=red>加速模式已开启:</color></b>");
if (InspectorUtils.InspectorButtonWithTextWidth($"关闭加速模式")) RemoveDefine(ENABLE_SPEEDUP);
#else
InspectorUtils.InspectorTextWidthRich($"<b>加速模式已关闭:</b>");
if (InspectorUtils.InspectorButtonWithTextWidth($"开启加速模式")) AddDefine(ENABLE_SPEEDUP);
#endif
EditorGUILayout.EndHorizontal();
EditorGUILayout.BeginHorizontal();
#if GAME_AUTO_DEBUG
InspectorUtils.InspectorTextWidthRich($"<b><color=red>自动战斗已开启:</color></b>");
if (InspectorUtils.InspectorButtonWithTextWidth($"关闭自动战斗")) RemoveAutoBattle();
if (InspectorUtils.InspectorButtonWithTextWidth($"关闭自动战斗")) RemoveDefine(GAME_AUTO_DEBUG);
#else
InspectorUtils.InspectorTextWidthRich($"<b>自动战斗已关闭:</b>");
if (InspectorUtils.InspectorButtonWithTextWidth($"开启自动战斗")) AddDefine(GAME_AUTO_DEBUG);