544 lines
22 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
* @Author: 白哉
* @Description:
* @Date: 2025年11月25日 星期二 15:11:13
* @Modify:
*/
using System;
using System.Collections.Generic;
using Logic.Action;
using Logic.AI;
using Logic.CrashSight;
using Logic.Pool;
using MemoryPack;
using RuntimeData;
using TH1_Logic.Core;
using TH1_Logic.Tools;
using UnityEngine;
namespace TH1_Logic.AITrain
{
/// <summary>
/// Bridges the running game and the AI training pipeline: builds the observation (state) vector,
/// encodes/decodes actions through <see cref="AIActionPacker"/>, and computes score values whose
/// differences are used as rewards.
/// </summary>
public class TrainingState
{
    public static TrainingState Instance = new TrainingState();

    // Slot counts of the fixed-capacity sections of the vector built by GetMapState.
    // Changing any of them shifts every later offset of the state vector.
    private const int PlayerSlots = 10;
    private const int UnitSlots = 80;
    private const int CitySlots = 40;
    private const int GridFeatureSlots = 100;

    // CommonActionId <-> dense index table, loaded on first use by LoadActionLogicIdData.
    private ActionLogicIdData _actionLogicIdData;

    /// <summary>Initializes the environment by starting a new match.</summary>
    public void Initialize()
    {
        Main.Instance.StartMatch();
    }

    /// <summary>
    /// Builds the observation vector for <paramref name="selfPlayer"/>:
    /// 2 global values + 10 players x 4 + 80 units x 5 + 40 cities x 4 + 100 grid features x 2 = 802 floats.
    /// </summary>
    /// <remarks>
    /// The layout may only be extended by appending at the end. The choice of state features matters a lot:
    /// together with the action encoding it determines everything the AI can learn, so any behavior the AI
    /// should be steered toward must be reflected in some state feature.
    /// Units and cities are taken from the grids in the player's sight. Each section is zero-padded when
    /// there are fewer entries than slots and truncated when there are more.
    /// </remarks>
    public float[] GetMapState(MapData map, PlayerData selfPlayer)
    {
        using var pooledState = THCollectionPool.GetListHandle<float>(out var state);

        // Units/cities owned by the observer, used for the "is mine" flags below.
        using var pooledSelfUnits = THCollectionPool.GetHashSetHandle<UnitData>(out var selfUnits);
        map.GetUnitDataListByPlayerId(selfPlayer.Id, selfUnits);
        using var pooledSelfCities = THCollectionPool.GetHashSetHandle<CityData>(out var selfCities);
        map.GetCityDataListByPlayerId(selfPlayer.Id, selfCities);
        var maxScore = map.PlayerMap.GetMaxPlayerScore();

        // Global section (2): turn progress (saturates at turn 100) and the observer's player index.
        state.Add(Mathf.Min(selfPlayer.Turn / 100f, 1));
        state.Add(GetPlayerIndex(map, selfPlayer.Id) / 10f);

        // Gather the units and cities standing on grids the observer can currently see.
        using var pooledUnits = THCollectionPool.GetListHandle<UnitData>(out var units);
        using var pooledCities = THCollectionPool.GetListHandle<CityData>(out var cities);
        foreach (var gid in selfPlayer.Sight.SightGidSet)
        {
            // TODO(白哉): confirm this visibility rule.
            if (map.GridMap.GetGridDataByGid(gid, out var visibleGrid) && visibleGrid.RealUnit(map, out var visibleUnit)) units.Add(visibleUnit);
            if (map.GetCityDataByGid(gid, out var visibleCity)) cities.Add(visibleCity);
        }

        // Player section (10 x 4): territory ratio, sight ratio, coins (saturating at 50), score relative to the leader.
        var players = map.PlayerMap.PlayerDataList;
        for (int i = 0; i < PlayerSlots; i++)
        {
            if (i >= players.Count)
            {
                for (int k = 0; k < 4; k++) state.Add(0);
                continue;
            }
            var player = players[i];
            var territory = map.GetPlayerTerritoryGridIdSet(player.Id);
            state.Add((float)territory.Count / map.MapConfig.Height / map.MapConfig.Width);
            state.Add((float)player.Sight.SightGidSet.Count / map.MapConfig.Height / map.MapConfig.Width);
            state.Add(Mathf.Min(player.PlayerCoin / 50f, 1));
            state.Add(maxScore > 0 ? player.PlayerScore / (float)maxScore : 0f);
        }

        // Unit section (80 x 5): unit index, is-mine flag, grid index, unit type index, health ratio.
        for (int i = 0; i < UnitSlots; i++)
        {
            if (i >= units.Count)
            {
                for (int k = 0; k < 5; k++) state.Add(0);
                continue;
            }
            var unit = units[i];
            var unitGrid = unit.Grid(map);
            state.Add(GetUnitIndex(map, unit.Id) / 100f);
            state.Add(selfUnits.Contains(unit) ? 1 : 0);
            state.Add(GetGridIndex(map, unitGrid.Id) / 500f);
            state.Add(Table.Instance.UnitTypeDataAssets.GetUnitTypeInfoIndex(unit) / 100f);
            state.Add(unit.GetHealthRatio());
        }

        // City section (40 x 4): city index, is-mine flag, grid index, level progress (saturating at 1).
        for (int i = 0; i < CitySlots; i++)
        {
            if (i >= cities.Count)
            {
                for (int k = 0; k < 4; k++) state.Add(0);
                continue;
            }
            var city = cities[i];
            var cityGrid = city.Grid(map);
            state.Add(GetCityIndex(map, city.Id) / 100f);
            state.Add(selfCities.Contains(city) ? 1 : 0);
            state.Add(GetGridIndex(map, cityGrid.Id) / 500f);
            // NOTE(review): divides by city.Level; a level-0 city would produce Infinity/NaN before Mathf.Min.
            state.Add(Mathf.Min((city.Level + city.LevelExp / (float)city.Level / 2f) / 10f, 1f));
        }

        // Grid section (100 x 2, kept last so it can be extended): one (grid index, feature) pair per
        // resource and per special type found on visible grids.
        var gridStateCount = 0;
        foreach (var gid in selfPlayer.Sight.SightGidSet)
        {
            // Skip ids that do not resolve to a grid. (This check used to be inverted, which skipped
            // every valid grid and read the out value of unresolved ones.)
            if (!map.GridMap.GetGridDataByGid(gid, out var grid)) continue;
            if (grid.Resource != ResourceType.None)
            {
                state.Add(GetGridIndex(map, grid.Id) / 500f);
                state.Add((int)grid.Resource / 100f);
                gridStateCount++;
                if (gridStateCount >= GridFeatureSlots) break;
            }
            foreach (var spType in grid.SpTypeList)
            {
                state.Add(GetGridIndex(map, grid.Id) / 500f);
                state.Add((int)spType / 30f);
                gridStateCount++;
                if (gridStateCount >= GridFeatureSlots) break;
            }
            if (gridStateCount >= GridFeatureSlots) break;
        }
        for (int i = gridStateCount; i < GridFeatureSlots; i++)
        {
            state.Add(0);
            state.Add(0);
        }

        return state.ToArray();
    }

    /// <summary>
    /// Encodes an action as the 8 normalized floats produced by <see cref="AIActionPacker.TryPack"/>.
    /// Returns false when any encoded value falls outside [0, 1] (the packer logs the error).
    /// </summary>
    public bool GetActionBitCodec(CommonActionId actionId, CommonActionParams param, out List<float> packed)
    {
        return CreatePacker(actionId, param).TryPack(out packed);
    }

    /// <summary>
    /// Same encoding as <see cref="GetActionBitCodec"/> but without range validation; returns whatever
    /// <see cref="AIActionPacker.TryPackWithOutLimit"/> returns.
    /// </summary>
    public bool GetActionBitCodecWithoutLimit(CommonActionId actionId, CommonActionParams param, out List<float> packed)
    {
        return CreatePacker(actionId, param).TryPackWithOutLimit(out packed);
    }

    // Maps an action id and its parameters to list indices. Parameter ids equal to 0 are treated as
    // "not set" and keep the packer's -1 default.
    private AIActionPacker CreatePacker(CommonActionId actionId, CommonActionParams param)
    {
        LoadActionLogicIdData();
        var packer = new AIActionPacker
        {
            ActionId = _actionLogicIdData.GetActionIdIndex(actionId)
        };
        var map = param.MapData;
        if (param.PlayerId != 0) packer.PlayerIndex = GetPlayerIndex(map, param.PlayerId);
        if (param.UnitId != 0) packer.UnitIndex = GetUnitIndex(map, param.UnitId);
        if (param.CityId != 0) packer.CityIndex = GetCityIndex(map, param.CityId);
        if (param.GridId != 0) packer.GridIndex = GetGridIndex(map, param.GridId);
        if (param.TargetUnitId != 0) packer.TargetUnitIndex = GetUnitIndex(map, param.TargetUnitId);
        if (param.TargetGridId != 0) packer.TargetGridIndex = GetGridIndex(map, param.TargetGridId);
        if (param.TargetPlayerId != 0) packer.TargetPlayerIndex = GetPlayerIndex(map, param.TargetPlayerId);
        return packer;
    }

    /// <summary>
    /// Score of <paramref name="player"/> on <paramref name="mapData"/>: (unit score + city score) / 5.
    /// The difference between successive scores is used as the reward.
    /// </summary>
    public float GetMapScore(MapData mapData, PlayerData player)
    {
        var score = GetUnitsScore(mapData, player) + GetCityScore(mapData, player);
        return score / 5f;
    }

    /// <summary>
    /// Encodes every action from <c>AIActionGenerator.GeneratorAllActionIdsForUse</c>.
    /// <paramref name="actions"/> is index-aligned with the returned list: actions that cannot be
    /// encoded are removed from both.
    /// </summary>
    public List<List<float>> GetAllActionBitCodecForUse(MapData mapData, PlayerData selfPlayer, out List<AIActionBase> actions)
    {
        var actionList = AIActionGenerator.GeneratorAllActionIdsForUse(mapData, selfPlayer);
        return PackEncodableActions(actionList, out actions);
    }

    /// <summary>
    /// Encodes every action from <c>AIActionGenerator.GeneratorAllActionIds</c>.
    /// <paramref name="actions"/> is index-aligned with the returned list: actions that cannot be
    /// encoded are removed from both.
    /// </summary>
    public List<List<float>> GetAllActionBitCodec(MapData mapData, PlayerData selfPlayer, out List<AIActionBase> actions)
    {
        var actionList = AIActionGenerator.GeneratorAllActionIds(mapData, selfPlayer);
        return PackEncodableActions(actionList, out actions);
    }

    // Encodes every candidate. Actions that fail to encode are dropped from BOTH outputs, so packed[i]
    // always describes encodedActions[i]. (Previously only the packed list was filtered, which misaligned
    // the two lists after the first failure.) When nothing is dropped, the candidate list itself is returned.
    private List<List<float>> PackEncodableActions(List<AIActionBase> candidates, out List<AIActionBase> encodedActions)
    {
        var packedList = new List<List<float>>(candidates.Count);
        List<AIActionBase> kept = null;
        for (int i = 0; i < candidates.Count; i++)
        {
            var action = candidates[i];
            if (!GetActionBitCodec(action.ActionLogic.ActionId, action.Param, out var packed))
            {
                // First rejection: copy the accepted prefix; later accepted actions are appended to the copy.
                kept ??= candidates.GetRange(0, i);
                continue;
            }
            packedList.Add(packed);
            kept?.Add(action);
        }
        encodedActions = kept ?? candidates;
        return packedList;
    }

    /// <summary>
    /// Decodes a packed action back into its <see cref="CommonActionId"/> and parameters on <paramref name="mapData"/>.
    /// Returns false if unpacking fails or any referenced index does not exist on the map.
    /// </summary>
    public bool GetActionFromBitCodec(List<float> packed, MapData mapData, out CommonActionId actionId,
        out CommonActionParams param)
    {
        actionId = default;
        param = null;
        LoadActionLogicIdData();
        var packer = new AIActionPacker();
        if (!packer.TryUnpack(packed)) return false;

        // Map the action index back to a CommonActionId.
        if (packer.ActionId < 0 || packer.ActionId >= _actionLogicIdData.ActionIdList.Count) return false;
        actionId = _actionLogicIdData.ActionIdList[packer.ActionId];

        var p = new CommonActionParams
        {
            MapData = mapData
        };

        // Resolve each present index (-1 = absent); an index that is out of range fails the whole decode.
        if (packer.PlayerIndex != -1)
        {
            if (packer.PlayerIndex < 0 || packer.PlayerIndex >= mapData.PlayerMap.PlayerDataList.Count) return false;
            p.PlayerData = mapData.PlayerMap.PlayerDataList[packer.PlayerIndex];
        }
        if (packer.UnitIndex != -1)
        {
            if (packer.UnitIndex < 0 || packer.UnitIndex >= mapData.UnitMap.UnitList.Count) return false;
            p.UnitData = mapData.UnitMap.UnitList[packer.UnitIndex];
        }
        if (packer.CityIndex != -1)
        {
            if (packer.CityIndex < 0 || packer.CityIndex >= mapData.CityMap.CityList.Count) return false;
            p.CityData = mapData.CityMap.CityList[packer.CityIndex];
        }
        if (packer.GridIndex != -1)
        {
            if (packer.GridIndex < 0 || packer.GridIndex >= mapData.GridMap.GridList.Count) return false;
            p.GridData = mapData.GridMap.GridList[packer.GridIndex];
        }
        if (packer.TargetUnitIndex != -1)
        {
            if (packer.TargetUnitIndex < 0 || packer.TargetUnitIndex >= mapData.UnitMap.UnitList.Count) return false;
            p.TargetUnitData = mapData.UnitMap.UnitList[packer.TargetUnitIndex];
        }
        if (packer.TargetGridIndex != -1)
        {
            // Lower bound must test TargetGridIndex (it used to test GridIndex, rejecting target-only actions).
            if (packer.TargetGridIndex < 0 || packer.TargetGridIndex >= mapData.GridMap.GridList.Count) return false;
            p.TargetGridData = mapData.GridMap.GridList[packer.TargetGridIndex];
        }
        if (packer.TargetPlayerIndex != -1)
        {
            if (packer.TargetPlayerIndex < 0 || packer.TargetPlayerIndex >= mapData.PlayerMap.PlayerDataList.Count)
                return false;
            p.TargetPlayerData = mapData.PlayerMap.PlayerDataList[packer.TargetPlayerIndex];
        }

        param = p;
        param.MainObjectType = ActionLogicFactory.GetMainObjectType(actionId.ActionType);
        param.OnParamChanged();
        return true;
    }

    // Position of the player with `id` in PlayerDataList, or the list count if absent.
    private int GetPlayerIndex(MapData map, uint id)
    {
        int index = 0;
        foreach (var p in map.PlayerMap.PlayerDataList)
        {
            if (p.Id == id) return index;
            index++;
        }
        return index;
    }

    // Position of the unit with `id` in UnitList, or the list count if absent.
    private int GetUnitIndex(MapData map, uint id)
    {
        int index = 0;
        foreach (var u in map.UnitMap.UnitList)
        {
            if (u.Id == id) return index;
            index++;
        }
        return index;
    }

    // Position of the city with `id` in CityList, or the list count if absent.
    private int GetCityIndex(MapData map, uint id)
    {
        int index = 0;
        foreach (var c in map.CityMap.CityList)
        {
            if (c.Id == id) return index;
            index++;
        }
        return index;
    }

    // Grid index as reported by GridMap.GetGridIndexByGid.
    private int GetGridIndex(MapData map, uint id)
    {
        return map.GridMap.GetGridIndexByGid(id);
    }

    /// <summary>
    /// Rank-based score: players are sorted by PlayerScore (descending). First place gets 100 and last place
    /// gets 0, interpolated linearly, and SightGidSet.Count / 10 is added on top. Returns 0 with no players
    /// and 10 with exactly one player.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the single-player value (10, no sight bonus) does not match the multi-player scale
    /// (first place = 100 + sight bonus); confirm which scale is intended.
    /// A player not found in the list keeps rank 0, i.e. it is scored as first place.
    /// </remarks>
    public float GetPlayerScore(MapData mapData, PlayerData playerData)
    {
        var playerList = mapData.PlayerMap.PlayerDataList;
        if (playerList.Count == 0) return 0f;
        if (playerList.Count == 1) return 10f;

        var sortedPlayers = new List<PlayerData>(playerList);
        sortedPlayers.Sort((a, b) => b.PlayerScore.CompareTo(a.PlayerScore));
        int rank = 0;
        for (int i = 0; i < sortedPlayers.Count; i++)
        {
            if (sortedPlayers[i].Id == playerData.Id)
            {
                rank = i;
                break;
            }
        }

        float score = 100f * (1f - (float)rank / (playerList.Count - 1));
        score += playerData.Sight.SightGidSet.Count / 10f;
        return score;
    }

    // Sum over living units of (attack range + move range + attack + defense):
    // units owned by playerData add to the score, all other units subtract from it.
    private float GetUnitsScore(MapData mapData, PlayerData playerData)
    {
        float score = 0f;
        using var pooledUnits = THCollectionPool.GetHashSetHandle<UnitData>(out var units);
        mapData.GetUnitDataListByPlayerId(playerData.Id, units);
        foreach (var unit in mapData.UnitMap.UnitList)
        {
            if (!unit.IsAlive()) continue;
            var unitScore = unit.GetAttackRange(mapData) + unit.GetMoveRange(mapData) +
                            unit.GetAllAttackValue(mapData) + unit.GetAllDefenseValue(mapData);
            if (units.Contains(unit)) score += unitScore;
            else score -= unitScore;
        }
        return score;
    }

    // Sum over all cities of (level * 10 + territory area * 2):
    // cities owned by playerData add to the score, all other cities subtract from it.
    private float GetCityScore(MapData mapData, PlayerData playerData)
    {
        float score = 0f;
        using var pooledCities = THCollectionPool.GetHashSetHandle<CityData>(out var cities);
        mapData.GetCityDataListByPlayerId(playerData.Id, cities);
        foreach (var city in mapData.CityMap.CityList)
        {
            var cityScore = city.Level * 10f + city.Territory.TerritoryArea.Count * 2;
            if (cities.Contains(city)) score += cityScore;
            else score -= cityScore;
        }
        return score;
    }

    // Loads the CommonActionId table from the "CommonIdData/CommonIdData" TextAsset once.
    // A missing asset or a null deserialization result yields an empty table.
    private void LoadActionLogicIdData()
    {
        if (_actionLogicIdData != null) return;
        TextAsset asset = TH1Resource.ResourceLoader.Load<TextAsset>($"CommonIdData/CommonIdData");
        // Explicit != null instead of `?.`: null propagation bypasses UnityEngine.Object's overloaded null check.
        var data = asset != null ? asset.bytes : Array.Empty<byte>();
        _actionLogicIdData = TH1Serialization.Deserialize<ActionLogicIdData>(data) ?? new ActionLogicIdData();
    }
}
/// <summary>
/// Serializable table assigning each <see cref="CommonActionId"/> a dense index: its position in
/// <see cref="ActionIdList"/>. A reverse lookup dictionary is rebuilt lazily on first use.
/// </summary>
[MemoryPackable]
public partial class ActionLogicIdData
{
    // Ordered action ids; an id's position in this list is its index. This is the serialized payload.
    public List<CommonActionId> ActionIdList = new();

    // Reverse lookup id -> index (private, so not part of the serialized data).
    private Dictionary<CommonActionId, int> _idIndex = new();

    // Number of ActionIdList entries _idIndex was built from. Tracked explicitly because
    // _idIndex.Count never equals ActionIdList.Count when the list contains duplicates, which
    // made the old "counts equal" check rebuild the dictionary on every lookup.
    private int _indexedCount;

    /// <summary>Rebuilds the reverse lookup if ActionIdList changed size since it was last indexed.</summary>
    public void Initialize()
    {
        if (_indexedCount == ActionIdList.Count) return;
        _idIndex.Clear();
        for (int i = 0; i < ActionIdList.Count; i++)
        {
            // Last occurrence wins for duplicate ids.
            _idIndex[ActionIdList[i]] = i;
        }
        _indexedCount = ActionIdList.Count;
    }

    /// <summary>Returns the index of <paramref name="actionId"/>, or -1 if it is not in the table.</summary>
    public int GetActionIdIndex(CommonActionId actionId)
    {
        Initialize();
        return _idIndex.GetValueOrDefault(actionId, -1);
    }

    /// <summary>Appends <paramref name="actionId"/> to the table unless it is already present.</summary>
    public void AddActionId(CommonActionId actionId)
    {
        Initialize();
        if (!_idIndex.TryAdd(actionId, ActionIdList.Count)) return;
        ActionIdList.Add(actionId);
        _indexedCount = ActionIdList.Count;
    }
}
/// <summary>
/// Encodes an action (action index plus up to seven object indices) as 8 floats and decodes it back.
/// Each index is stored as (index + 1) / max, so the "absent" value -1 becomes 0 and in-range
/// indices land in (0, 1].
/// </summary>
public class AIActionPacker
{
    // Number of floats produced by TryPack / TryPackWithOutLimit and consumed by TryUnpack.
    private const int PackedLength = 8;

    public int ActionId;
    public int PlayerIndex;
    public int UnitIndex;
    public int CityIndex;
    public int GridIndex;
    public int TargetUnitIndex;
    public int TargetGridIndex;
    public int TargetPlayerIndex;

    // Normalization divisors for each slot.
    public int MaxActionID = 500;
    public int MaxPlayerIndex = 10;
    public int MaxUnitIndex = 80;
    public int MaxCityIndex = 40;
    public int MaxGridIndex = 500;
    public int MaxTargetUnitIndex = 80;
    public int MaxTargetGridIndex = 500;
    public int MaxTargetPlayerIndex = 10;

    /// <summary>Creates a packer with every object index set to -1 (absent); ActionId starts at 0.</summary>
    public AIActionPacker()
    {
        PlayerIndex = -1;
        UnitIndex = -1;
        CityIndex = -1;
        GridIndex = -1;
        TargetUnitIndex = -1;
        TargetGridIndex = -1;
        TargetPlayerIndex = -1;
    }

    /// <summary>
    /// Encodes the current fields into <paramref name="packed"/> (always assigned).
    /// Returns false and logs an error when any encoded value is outside [0, 1].
    /// </summary>
    public bool TryPack(out List<float> packed)
    {
        packed = Encode();
        if (AllInUnitRange(packed)) return true;
        LogOutOfRange();
        return false;
    }

    /// <summary>Encodes the current fields into <paramref name="packed"/> without range validation; always returns true.</summary>
    public bool TryPackWithOutLimit(out List<float> packed)
    {
        packed = Encode();
        return true;
    }

    /// <summary>
    /// Decodes <paramref name="packed"/> into the index fields, slot order matching <see cref="TryPack"/>.
    /// Returns false if the input is null or shorter than 8 values (fields untouched), or if any value
    /// is outside [0, 1] (fields already overwritten, error logged).
    /// </summary>
    public bool TryUnpack(List<float> packed)
    {
        if (packed == null || packed.Count < PackedLength)
        {
            LogSystem.LogError($"AIActionPacker.TryUnpack: expected {PackedLength} values, got {packed?.Count ?? 0}");
            return false;
        }

        ActionId = Denormalize(packed[0], MaxActionID);
        PlayerIndex = Denormalize(packed[1], MaxPlayerIndex);
        UnitIndex = Denormalize(packed[2], MaxUnitIndex);
        CityIndex = Denormalize(packed[3], MaxCityIndex);
        GridIndex = Denormalize(packed[4], MaxGridIndex);
        TargetUnitIndex = Denormalize(packed[5], MaxTargetUnitIndex);
        TargetGridIndex = Denormalize(packed[6], MaxTargetGridIndex);
        TargetPlayerIndex = Denormalize(packed[7], MaxTargetPlayerIndex);

        if (AllInUnitRange(packed)) return true;
        LogOutOfRange();
        return false;
    }

    // Builds the 8-value encoding; the slot order here defines the wire format.
    private List<float> Encode()
    {
        return new List<float>(PackedLength)
        {
            Normalize(ActionId, MaxActionID),
            Normalize(PlayerIndex, MaxPlayerIndex),
            Normalize(UnitIndex, MaxUnitIndex),
            Normalize(CityIndex, MaxCityIndex),
            Normalize(GridIndex, MaxGridIndex),
            Normalize(TargetUnitIndex, MaxTargetUnitIndex),
            Normalize(TargetGridIndex, MaxTargetGridIndex),
            Normalize(TargetPlayerIndex, MaxTargetPlayerIndex),
        };
    }

    // index -> (index + 1) / max, so -1 (absent) encodes as 0.
    private static float Normalize(int index, int max) => (float)(index + 1) / max;

    // Inverse of Normalize, rounding to the nearest integer to absorb float error.
    private static int Denormalize(float value, int max) => (int)Math.Round(value * max) - 1;

    // True when every value (including any beyond the 8 used slots) lies in [0, 1].
    private static bool AllInUnitRange(List<float> values)
    {
        foreach (var value in values)
        {
            if (value > 1 || value < 0) return false;
        }
        return true;
    }

    private void LogOutOfRange()
    {
        LogSystem.LogError($"数据越界!!! {ActionId} {PlayerIndex} {UnitIndex} {CityIndex} {GridIndex} {TargetUnitIndex} {TargetGridIndex} {TargetPlayerIndex}");
    }
}
}