From 9c8fd9d5ccf49619bdf4a8f3681a54be4e1e8ccd Mon Sep 17 00:00:00 2001 From: wuwenbo Date: Wed, 17 Dec 2025 18:11:37 +0800 Subject: [PATCH] =?UTF-8?q?AI=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Resources/CommonIdData/CommonIdData.bytes | Bin 17048 -> 18014 bytes Unity/Assets/Scripts/TH1_Data/MapData.cs | 9 +- Unity/Assets/Scripts/TH1_Logic/AI/AILogic.cs | 114 +++++++++++++----- .../TH1_Logic/AITrain/TrainingDataRecorder.cs | 9 +- .../TH1_Logic/AITrain/TrainingState.cs | 3 +- .../Scripts/TH1_Logic/Action/ActionLogic.cs | 2 +- .../Scripts/TH1_Logic/Core/GameLogic.cs | 6 +- 7 files changed, 107 insertions(+), 36 deletions(-) diff --git a/Unity/Assets/Resources/CommonIdData/CommonIdData.bytes b/Unity/Assets/Resources/CommonIdData/CommonIdData.bytes index 325a3d2dec2d36a99ba62ce135cb926da054b075..e32ee46b4997be9d4afc754a629efa24af2fa613 100644 GIT binary patch delta 232 zcmbQy%6PAbk&Tg+k%2*QvLlo5ne$yZKdAyPqzDpX$)8yeHTfa8?j&oDNy0oVpSL$3 z25~3fRuo|=IY066?0|SHLWJflOi3%JW-GZC|U1kRF diff --git a/Unity/Assets/Scripts/TH1_Data/MapData.cs b/Unity/Assets/Scripts/TH1_Data/MapData.cs index 2b3ff95a8..2f06f8062 100644 --- a/Unity/Assets/Scripts/TH1_Data/MapData.cs +++ b/Unity/Assets/Scripts/TH1_Data/MapData.cs @@ -1460,7 +1460,7 @@ namespace RuntimeData } // 游戏结束时获取胜利玩家 ID - public uint GetWinPlayer() + public uint GetWinPlayer(out bool isSingle) { uint winId = 0; int maxScore = 0; @@ -1471,6 +1471,11 @@ namespace RuntimeData maxScore = player.PlayerScore; winId = player.Id; } + + isSingle = false; + var aliveCount = 0; + foreach (var player in PlayerMap.PlayerDataList) if (player.Alive) aliveCount++; + isSingle = aliveCount == 1; return winId; } @@ -1479,7 +1484,7 @@ namespace RuntimeData public bool CheckIfGameEnd(out bool isWin) { #if GAME_AUTO_DEBUG - if (CurPlayer != null && CurPlayer.Turn > 31) + if (CurPlayer != null && CurPlayer.Turn > 40) { uint maxId = 0; int maxScore = 0; diff --git a/Unity/Assets/Scripts/TH1_Logic/AI/AILogic.cs b/Unity/Assets/Scripts/TH1_Logic/AI/AILogic.cs index f806d2fee..c763c0675 100644 --- a/Unity/Assets/Scripts/TH1_Logic/AI/AILogic.cs +++ b/Unity/Assets/Scripts/TH1_Logic/AI/AILogic.cs @@ -137,34 +137,8 @@ namespace Logic.AI if (AILogicState == AILogicState.Playing) { - #if ENABLE_AIMODEL - var predictState = TrainingState.Instance.GetMapState(_data.Map, _data.Map.CurPlayer); - var predictActionsId = TrainingState.Instance.GetAllActionBitCodecForUse(_data.Map, _data.Map.CurPlayer, out var predictActions); - if (predictActionsId.Count == 0) - { - AILogicState = AILogicState.Finished; - } - else - { - var predictIndex = ModelInference.Instance.Predict(predictState, predictActionsId); - if (!TrainingState.Instance.GetActionFromBitCodec(predictActionsId[predictIndex], _data.Map, out var actionId, out var param)) - { - LogSystem.LogError($"反序列化 Action 失败"); - } - else - { - _data.MaxAiAction = new AIActionBase(param, ActionLogicFactory.GetActionLogic(actionId)); - _data.MaxAiAction.ActionLogic.CompleteExecute(_data.MaxAiAction.Param); - LogSystem.LogWarning($"{_data.MaxAiAction.ActionLogic.ActionId.GetStringLog()}"); - _data.MaxAiAction.CheckIsActionInPlayerSight(); - _data.MaxAiAction.CheckIsActionDuration(); - _targetTime = Mathf.Max(_data.MaxAiAction.Duration, 0f); - _isWaitFrame = false; - _data.MaxAiAction = null; - AILogicState = AILogicState.Pausing; - } - } + AIModelExecute(); return; #endif @@ -224,6 +198,11 @@ namespace Logic.AI { var afterScore = TrainingState.Instance.GetMapScore(_data.MaxAiAction.Param.MapData, curPlayer); var reward = afterScore - beforeScore; + if (_data.MaxAiAction.ActionLogic.ActionId.UnitActionType == UnitActionType.Capture) + { + LogSystem.LogError($"占领!!!"); + reward += 10; + } TrainingDataRecorder.Instance.RecordStep(curPlayer.Id, state, validActions.Select(x => x.ToArray()).ToArray(), packed.ToArray(), reward); } #endif @@ -248,6 +227,87 @@ namespace Logic.AI } } + private void AIModelExecute() + { +#if ENABLE_AIMODEL + var predictState = TrainingState.Instance.GetMapState(_data.Map, _data.Map.CurPlayer); + var predictActionsId = TrainingState.Instance.GetAllActionBitCodecForUse(_data.Map, _data.Map.CurPlayer, out var predictActions); + if (predictActionsId.Count == 0) + { + AILogicState = AILogicState.Finished; + } + else + { + var predictIndex = ModelInference.Instance.Predict(predictState, predictActionsId); + for (int i = 0; i < predictActions.Count; i++) + { + if (predictActions[i].ActionLogic.ActionId.UnitActionType == UnitActionType.Capture) + predictIndex = i; + } + if (!TrainingState.Instance.GetActionFromBitCodec(predictActionsId[predictIndex], _data.Map, out var actionId, out var param)) + { + LogSystem.LogError($"反序列化 Action 失败"); + } + else + { + _data.MaxAiAction = new AIActionBase(param, ActionLogicFactory.GetActionLogic(actionId)); + +#if ENABLE_TRAIN + bool isPack = TrainingState.Instance.GetActionBitCodec(_data.MaxAiAction.ActionLogic.ActionId, _data.MaxAiAction.Param, out var packed); + var curPlayer = _data.MaxAiAction.Param.MapData.CurPlayer; + var beforeScore = TrainingState.Instance.GetMapScore(_data.MaxAiAction.Param.MapData, curPlayer); + var state = TrainingState.Instance.GetMapState(_data.MaxAiAction.Param.MapData, curPlayer); + var validActions = TrainingState.Instance.GetAllActionBitCodec(_data.MaxAiAction.Param.MapData, curPlayer, out var actions); + if (isPack) + { + bool found = false; + + foreach (var action in validActions) + { + if (action.Count != packed.Count) continue; + for (int i = 0; i < packed.Count; i++) + { + if (action[i] - packed[i] > 0.001f) break; + if (i == packed.Count - 1) found = true; + } + } + + if (!found) + { + validActions.Add(packed); + LogSystem.LogError($"训练数据出错: {_data.MaxAiAction.ActionLogic.ActionId.GetStringLog()}"); + } + } +#endif + + _data.MaxAiAction.ActionLogic.CompleteExecute(_data.MaxAiAction.Param); + +#if ENABLE_TRAIN + if (isPack) + { + var afterScore = TrainingState.Instance.GetMapScore(_data.MaxAiAction.Param.MapData, curPlayer); + var reward = afterScore - beforeScore; + if (_data.MaxAiAction.ActionLogic.ActionId.UnitActionType == UnitActionType.Capture) + { + LogSystem.LogError($"占领!!!"); + reward += 10; + } + TrainingDataRecorder.Instance.RecordStep(curPlayer.Id, state, validActions.Select(x => x.ToArray()).ToArray(), packed.ToArray(), reward); + } +#endif + + LogSystem.LogWarning($"{_data.MaxAiAction.ActionLogic.ActionId.GetStringLog()}"); + _data.MaxAiAction.CheckIsActionInPlayerSight(); + _data.MaxAiAction.CheckIsActionDuration(); + _targetTime = Mathf.Max(_data.MaxAiAction.Duration, 0f); + _isWaitFrame = false; + _data.MaxAiAction = null; + AILogicState = AILogicState.Pausing; + } + } +#endif + } + public static List GetCurrentAIRecords() { return AIRecordsDict.GetValueOrDefault(CurrentAIPlayerId); diff --git a/Unity/Assets/Scripts/TH1_Logic/AITrain/TrainingDataRecorder.cs b/Unity/Assets/Scripts/TH1_Logic/AITrain/TrainingDataRecorder.cs index a5d22ef00..92af697a5 100644 --- a/Unity/Assets/Scripts/TH1_Logic/AITrain/TrainingDataRecorder.cs +++ b/Unity/Assets/Scripts/TH1_Logic/AITrain/TrainingDataRecorder.cs @@ -50,7 +50,7 @@ namespace TH1_Logic.AITrain } // 记录单步数据到内存 - public void RecordStep(uint playerID, float[] state, float[][] validActions, float[] selectedAction, float reward, bool done=false) + public void RecordStep(uint playerID, float[] state, float[][] validActions, float[] selectedAction, float reward) { if (!_episodeData.ContainsKey(playerID)) { @@ -69,14 +69,17 @@ namespace TH1_Logic.AITrain } // 游戏结束时一次性写入文件 - public void SaveEpisode() + public void SaveEpisode(uint playerId, bool isSingle) { if (_episodeData.Count == 0) return; - foreach (var kv in _episodeData) kv.Value[^1].Done = true; string timestamp = System.DateTime.Now.ToString("yyyyMMdd_HHmmss"); string uniqueId = System.Guid.NewGuid().ToString("N").Substring(0, 8); foreach (var kv in _episodeData) { + if (kv.Key != playerId) continue; + + kv.Value[^1].Done = true; + if (isSingle) kv.Value[^1].Reward += 50; var fileName = $"episode_{timestamp}_{uniqueId}_{kv.Key}.jsonl"; string filePath = Path.Combine(_outputDir, fileName); StringBuilder sb = new StringBuilder(); diff --git a/Unity/Assets/Scripts/TH1_Logic/AITrain/TrainingState.cs b/Unity/Assets/Scripts/TH1_Logic/AITrain/TrainingState.cs index d2df34598..ed559a6d3 100644 --- a/Unity/Assets/Scripts/TH1_Logic/AITrain/TrainingState.cs +++ b/Unity/Assets/Scripts/TH1_Logic/AITrain/TrainingState.cs @@ -364,7 +364,8 @@ namespace TH1_Logic.AITrain mapData.GetUnitDataListByPlayerId(playerData.Id, units); foreach (var unit in mapData.UnitMap.UnitList) { - var unitScore = unit.Health + unit.GetAttackRange() + unit.GetMoveRange() + + if (!unit.IsAlive()) continue; + var unitScore = unit.Health + unit.GetAttackRange(mapData) + unit.GetMoveRange(mapData) + unit.GetAllAttackValue(mapData) + unit.GetAllDefenseValue(mapData); if (units.Contains(unit)) score += unitScore; else score -= unitScore; diff --git a/Unity/Assets/Scripts/TH1_Logic/Action/ActionLogic.cs b/Unity/Assets/Scripts/TH1_Logic/Action/ActionLogic.cs index 84244652d..a0e29491c 100644 --- a/Unity/Assets/Scripts/TH1_Logic/Action/ActionLogic.cs +++ b/Unity/Assets/Scripts/TH1_Logic/Action/ActionLogic.cs @@ -1903,7 +1903,7 @@ namespace Logic.Action if (!actionParams.MapData.GetGridDataByUnitId(actionParams.UnitData.Id, out var unitGrid)) return false; if (!actionParams.MapData.GetGridDataByUnitId(actionParams.TargetUnitData.Id, out var targetUnitGrid)) return false; if (actionParams.MapData.GridMap.CalcDistance(unitGrid, targetUnitGrid) > - actionParams.UnitData.GetAttackRange()) return false; + actionParams.UnitData.GetAttackRange(actionParams.MapData)) return false; if (!actionParams.UnitData.IsCanAttackAlly()) return false; if (!actionParams.UnitData.IsCanAttackTargetAlly(actionParams.MapData, actionParams.TargetUnitData)) diff --git a/Unity/Assets/Scripts/TH1_Logic/Core/GameLogic.cs b/Unity/Assets/Scripts/TH1_Logic/Core/GameLogic.cs index 487ad572e..0ec620c1a 100644 --- a/Unity/Assets/Scripts/TH1_Logic/Core/GameLogic.cs +++ b/Unity/Assets/Scripts/TH1_Logic/Core/GameLogic.cs @@ -142,7 +142,7 @@ namespace Logic var afterScore = TrainingState.Instance.GetMapScore(Main.MapData, curPlayer); var reward = afterScore - beforeScore; TrainingDataRecorder.Instance.RecordStep(curPlayer.Id, state, - validActions.Select(x => x.ToArray()).ToArray(), validActions[^1].ToArray(), reward, true); + validActions.Select(x => x.ToArray()).ToArray(), validActions[^1].ToArray(), reward); #endif } } @@ -425,7 +425,9 @@ namespace Logic var record = Main.MapData.ExportGameRecord(); GameRecordManager.Instance.AddRecord(record); - TrainingDataRecorder.Instance.SaveEpisode(); + + var playerId = Main.MapData.GetWinPlayer(out var isSingle); + TrainingDataRecorder.Instance.SaveEpisode(playerId, isSingle); // 添加延迟退出,确保数据保存完成 _gameLogic.Main.StartCoroutine(QuitGameAfterDelay(10f));