AI更新
This commit is contained in:
parent
246b65869b
commit
9c8fd9d5cc
Binary file not shown.
@ -1460,7 +1460,7 @@ namespace RuntimeData
|
||||
}
|
||||
|
||||
// 游戏结束时获取胜利玩家 ID
|
||||
public uint GetWinPlayer()
|
||||
public uint GetWinPlayer(out bool isSingle)
|
||||
{
|
||||
uint winId = 0;
|
||||
int maxScore = 0;
|
||||
@ -1471,6 +1471,11 @@ namespace RuntimeData
|
||||
maxScore = player.PlayerScore;
|
||||
winId = player.Id;
|
||||
}
|
||||
|
||||
isSingle = false;
|
||||
var aliveCount = 0;
|
||||
foreach (var player in PlayerMap.PlayerDataList) if (player.Alive) aliveCount++;
|
||||
isSingle = aliveCount == 1;
|
||||
return winId;
|
||||
}
|
||||
|
||||
@ -1479,7 +1484,7 @@ namespace RuntimeData
|
||||
public bool CheckIfGameEnd(out bool isWin)
|
||||
{
|
||||
#if GAME_AUTO_DEBUG
|
||||
if (CurPlayer != null && CurPlayer.Turn > 31)
|
||||
if (CurPlayer != null && CurPlayer.Turn > 40)
|
||||
{
|
||||
uint maxId = 0;
|
||||
int maxScore = 0;
|
||||
|
||||
@ -137,34 +137,8 @@ namespace Logic.AI
|
||||
|
||||
if (AILogicState == AILogicState.Playing)
|
||||
{
|
||||
|
||||
#if ENABLE_AIMODEL
|
||||
var predictState = TrainingState.Instance.GetMapState(_data.Map, _data.Map.CurPlayer);
|
||||
var predictActionsId = TrainingState.Instance.GetAllActionBitCodecForUse(_data.Map, _data.Map.CurPlayer, out var predictActions);
|
||||
if (predictActionsId.Count == 0)
|
||||
{
|
||||
AILogicState = AILogicState.Finished;
|
||||
}
|
||||
else
|
||||
{
|
||||
var predictIndex = ModelInference.Instance.Predict(predictState, predictActionsId);
|
||||
if (!TrainingState.Instance.GetActionFromBitCodec(predictActionsId[predictIndex], _data.Map, out var actionId, out var param))
|
||||
{
|
||||
LogSystem.LogError($"反序列化 Action 失败");
|
||||
}
|
||||
else
|
||||
{
|
||||
_data.MaxAiAction = new AIActionBase(param, ActionLogicFactory.GetActionLogic(actionId));
|
||||
_data.MaxAiAction.ActionLogic.CompleteExecute(_data.MaxAiAction.Param);
|
||||
LogSystem.LogWarning($"{_data.MaxAiAction.ActionLogic.ActionId.GetStringLog()}");
|
||||
_data.MaxAiAction.CheckIsActionInPlayerSight();
|
||||
_data.MaxAiAction.CheckIsActionDuration();
|
||||
_targetTime = Mathf.Max(_data.MaxAiAction.Duration, 0f);
|
||||
_isWaitFrame = false;
|
||||
_data.MaxAiAction = null;
|
||||
AILogicState = AILogicState.Pausing;
|
||||
}
|
||||
}
|
||||
AIModelExecute();
|
||||
return;
|
||||
#endif
|
||||
|
||||
@ -224,6 +198,11 @@ namespace Logic.AI
|
||||
{
|
||||
var afterScore = TrainingState.Instance.GetMapScore(_data.MaxAiAction.Param.MapData, curPlayer);
|
||||
var reward = afterScore - beforeScore;
|
||||
if (_data.MaxAiAction.ActionLogic.ActionId.UnitActionType == UnitActionType.Capture)
|
||||
{
|
||||
LogSystem.LogError($"占领!!!");
|
||||
reward += 10;
|
||||
}
|
||||
TrainingDataRecorder.Instance.RecordStep(curPlayer.Id, state, validActions.Select(x => x.ToArray()).ToArray(), packed.ToArray(), reward);
|
||||
}
|
||||
#endif
|
||||
@ -248,6 +227,87 @@ namespace Logic.AI
|
||||
}
|
||||
}
|
||||
|
||||
private void AIModelExecute()
|
||||
{
|
||||
#if ENABLE_AIMODEL
|
||||
var predictState = TrainingState.Instance.GetMapState(_data.Map, _data.Map.CurPlayer);
|
||||
var predictActionsId = TrainingState.Instance.GetAllActionBitCodecForUse(_data.Map, _data.Map.CurPlayer, out var predictActions);
|
||||
if (predictActionsId.Count == 0)
|
||||
{
|
||||
AILogicState = AILogicState.Finished;
|
||||
}
|
||||
else
|
||||
{
|
||||
var predictIndex = ModelInference.Instance.Predict(predictState, predictActionsId);
|
||||
for (int i = 0; i < predictActions.Count; i++)
|
||||
{
|
||||
if (predictActions[i].ActionLogic.ActionId.UnitActionType == UnitActionType.Capture)
|
||||
predictIndex = i;
|
||||
}
|
||||
if (!TrainingState.Instance.GetActionFromBitCodec(predictActionsId[predictIndex], _data.Map, out var actionId, out var param))
|
||||
{
|
||||
LogSystem.LogError($"反序列化 Action 失败");
|
||||
}
|
||||
else
|
||||
{
|
||||
_data.MaxAiAction = new AIActionBase(param, ActionLogicFactory.GetActionLogic(actionId));
|
||||
|
||||
#if ENABLE_TRAIN
|
||||
bool isPack = TrainingState.Instance.GetActionBitCodec(_data.MaxAiAction.ActionLogic.ActionId, _data.MaxAiAction.Param, out var packed);
|
||||
var curPlayer = _data.MaxAiAction.Param.MapData.CurPlayer;
|
||||
var beforeScore = TrainingState.Instance.GetMapScore(_data.MaxAiAction.Param.MapData, curPlayer);
|
||||
var state = TrainingState.Instance.GetMapState(_data.MaxAiAction.Param.MapData, curPlayer);
|
||||
var validActions = TrainingState.Instance.GetAllActionBitCodec(_data.MaxAiAction.Param.MapData, curPlayer, out var actions);
|
||||
if (isPack)
|
||||
{
|
||||
bool found = false;
|
||||
|
||||
foreach (var action in validActions)
|
||||
{
|
||||
if (action.Count != packed.Count) continue;
|
||||
for (int i = 0; i < packed.Count; i++)
|
||||
{
|
||||
if (action[i] - packed[i] > 0.001f) break;
|
||||
if (i == packed.Count - 1) found = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!found)
|
||||
{
|
||||
validActions.Add(packed);
|
||||
LogSystem.LogError($"训练数据出错: {_data.MaxAiAction.ActionLogic.ActionId.GetStringLog()}");
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
_data.MaxAiAction.ActionLogic.CompleteExecute(_data.MaxAiAction.Param);
|
||||
|
||||
#if ENABLE_TRAIN
|
||||
if (isPack)
|
||||
{
|
||||
var afterScore = TrainingState.Instance.GetMapScore(_data.MaxAiAction.Param.MapData, curPlayer);
|
||||
var reward = afterScore - beforeScore;
|
||||
if (_data.MaxAiAction.ActionLogic.ActionId.UnitActionType == UnitActionType.Capture)
|
||||
{
|
||||
LogSystem.LogError($"占领!!!");
|
||||
reward += 10;
|
||||
}
|
||||
TrainingDataRecorder.Instance.RecordStep(curPlayer.Id, state, validActions.Select(x => x.ToArray()).ToArray(), packed.ToArray(), reward);
|
||||
}
|
||||
#endif
|
||||
|
||||
LogSystem.LogWarning($"{_data.MaxAiAction.ActionLogic.ActionId.GetStringLog()}");
|
||||
_data.MaxAiAction.CheckIsActionInPlayerSight();
|
||||
_data.MaxAiAction.CheckIsActionDuration();
|
||||
_targetTime = Mathf.Max(_data.MaxAiAction.Duration, 0f);
|
||||
_isWaitFrame = false;
|
||||
_data.MaxAiAction = null;
|
||||
AILogicState = AILogicState.Pausing;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
public static List<AIRecord> GetCurrentAIRecords()
|
||||
{
|
||||
return AIRecordsDict.GetValueOrDefault(CurrentAIPlayerId);
|
||||
|
||||
@ -50,7 +50,7 @@ namespace TH1_Logic.AITrain
|
||||
}
|
||||
|
||||
// 记录单步数据到内存
|
||||
public void RecordStep(uint playerID, float[] state, float[][] validActions, float[] selectedAction, float reward, bool done=false)
|
||||
public void RecordStep(uint playerID, float[] state, float[][] validActions, float[] selectedAction, float reward)
|
||||
{
|
||||
if (!_episodeData.ContainsKey(playerID))
|
||||
{
|
||||
@ -69,14 +69,17 @@ namespace TH1_Logic.AITrain
|
||||
}
|
||||
|
||||
// 游戏结束时一次性写入文件
|
||||
public void SaveEpisode()
|
||||
public void SaveEpisode(uint playerId, bool isSingle)
|
||||
{
|
||||
if (_episodeData.Count == 0) return;
|
||||
foreach (var kv in _episodeData) kv.Value[^1].Done = true;
|
||||
string timestamp = System.DateTime.Now.ToString("yyyyMMdd_HHmmss");
|
||||
string uniqueId = System.Guid.NewGuid().ToString("N").Substring(0, 8);
|
||||
foreach (var kv in _episodeData)
|
||||
{
|
||||
if (kv.Key != playerId) continue;
|
||||
|
||||
kv.Value[^1].Done = true;
|
||||
if (isSingle) kv.Value[^1].Reward += 50;
|
||||
var fileName = $"episode_{timestamp}_{uniqueId}_{kv.Key}.jsonl";
|
||||
string filePath = Path.Combine(_outputDir, fileName);
|
||||
StringBuilder sb = new StringBuilder();
|
||||
|
||||
@ -364,7 +364,8 @@ namespace TH1_Logic.AITrain
|
||||
mapData.GetUnitDataListByPlayerId(playerData.Id, units);
|
||||
foreach (var unit in mapData.UnitMap.UnitList)
|
||||
{
|
||||
var unitScore = unit.Health + unit.GetAttackRange() + unit.GetMoveRange() +
|
||||
if (!unit.IsAlive()) continue;
|
||||
var unitScore = unit.Health + unit.GetAttackRange(mapData) + unit.GetMoveRange(mapData) +
|
||||
unit.GetAllAttackValue(mapData) + unit.GetAllDefenseValue(mapData);
|
||||
if (units.Contains(unit)) score += unitScore;
|
||||
else score -= unitScore;
|
||||
|
||||
@ -1903,7 +1903,7 @@ namespace Logic.Action
|
||||
if (!actionParams.MapData.GetGridDataByUnitId(actionParams.UnitData.Id, out var unitGrid)) return false;
|
||||
if (!actionParams.MapData.GetGridDataByUnitId(actionParams.TargetUnitData.Id, out var targetUnitGrid)) return false;
|
||||
if (actionParams.MapData.GridMap.CalcDistance(unitGrid, targetUnitGrid) >
|
||||
actionParams.UnitData.GetAttackRange()) return false;
|
||||
actionParams.UnitData.GetAttackRange(actionParams.MapData)) return false;
|
||||
|
||||
if (!actionParams.UnitData.IsCanAttackAlly()) return false;
|
||||
if (!actionParams.UnitData.IsCanAttackTargetAlly(actionParams.MapData, actionParams.TargetUnitData))
|
||||
|
||||
@ -142,7 +142,7 @@ namespace Logic
|
||||
var afterScore = TrainingState.Instance.GetMapScore(Main.MapData, curPlayer);
|
||||
var reward = afterScore - beforeScore;
|
||||
TrainingDataRecorder.Instance.RecordStep(curPlayer.Id, state,
|
||||
validActions.Select(x => x.ToArray()).ToArray(), validActions[^1].ToArray(), reward, true);
|
||||
validActions.Select(x => x.ToArray()).ToArray(), validActions[^1].ToArray(), reward);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@ -425,7 +425,9 @@ namespace Logic
|
||||
|
||||
var record = Main.MapData.ExportGameRecord();
|
||||
GameRecordManager.Instance.AddRecord(record);
|
||||
TrainingDataRecorder.Instance.SaveEpisode();
|
||||
|
||||
var playerId = Main.MapData.GetWinPlayer(out var isSingle);
|
||||
TrainingDataRecorder.Instance.SaveEpisode(playerId, isSingle);
|
||||
|
||||
// 添加延迟退出,确保数据保存完成
|
||||
_gameLogic.Main.StartCoroutine(QuitGameAfterDelay(10f));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user