This commit is contained in:
wuwenbo 2025-12-17 18:11:37 +08:00
parent 246b65869b
commit 9c8fd9d5cc
7 changed files with 107 additions and 36 deletions

View File

@ -1460,7 +1460,7 @@ namespace RuntimeData
}
// 游戏结束时获取胜利玩家 ID
public uint GetWinPlayer()
public uint GetWinPlayer(out bool isSingle)
{
uint winId = 0;
int maxScore = 0;
@ -1471,6 +1471,11 @@ namespace RuntimeData
maxScore = player.PlayerScore;
winId = player.Id;
}
isSingle = false;
var aliveCount = 0;
foreach (var player in PlayerMap.PlayerDataList) if (player.Alive) aliveCount++;
isSingle = aliveCount == 1;
return winId;
}
@ -1479,7 +1484,7 @@ namespace RuntimeData
public bool CheckIfGameEnd(out bool isWin)
{
#if GAME_AUTO_DEBUG
if (CurPlayer != null && CurPlayer.Turn > 31)
if (CurPlayer != null && CurPlayer.Turn > 40)
{
uint maxId = 0;
int maxScore = 0;

View File

@ -137,34 +137,8 @@ namespace Logic.AI
if (AILogicState == AILogicState.Playing)
{
#if ENABLE_AIMODEL
var predictState = TrainingState.Instance.GetMapState(_data.Map, _data.Map.CurPlayer);
var predictActionsId = TrainingState.Instance.GetAllActionBitCodecForUse(_data.Map, _data.Map.CurPlayer, out var predictActions);
if (predictActionsId.Count == 0)
{
AILogicState = AILogicState.Finished;
}
else
{
var predictIndex = ModelInference.Instance.Predict(predictState, predictActionsId);
if (!TrainingState.Instance.GetActionFromBitCodec(predictActionsId[predictIndex], _data.Map, out var actionId, out var param))
{
LogSystem.LogError($"反序列化 Action 失败");
}
else
{
_data.MaxAiAction = new AIActionBase(param, ActionLogicFactory.GetActionLogic(actionId));
_data.MaxAiAction.ActionLogic.CompleteExecute(_data.MaxAiAction.Param);
LogSystem.LogWarning($"{_data.MaxAiAction.ActionLogic.ActionId.GetStringLog()}");
_data.MaxAiAction.CheckIsActionInPlayerSight();
_data.MaxAiAction.CheckIsActionDuration();
_targetTime = Mathf.Max(_data.MaxAiAction.Duration, 0f);
_isWaitFrame = false;
_data.MaxAiAction = null;
AILogicState = AILogicState.Pausing;
}
}
AIModelExecute();
return;
#endif
@ -224,6 +198,11 @@ namespace Logic.AI
{
var afterScore = TrainingState.Instance.GetMapScore(_data.MaxAiAction.Param.MapData, curPlayer);
var reward = afterScore - beforeScore;
if (_data.MaxAiAction.ActionLogic.ActionId.UnitActionType == UnitActionType.Capture)
{
LogSystem.LogError($"占领!!!");
reward += 10;
}
TrainingDataRecorder.Instance.RecordStep(curPlayer.Id, state, validActions.Select(x => x.ToArray()).ToArray(), packed.ToArray(), reward);
}
#endif
@ -248,6 +227,87 @@ namespace Logic.AI
}
}
/// <summary>
/// Runs one AI step driven by the inference model: builds the current map state and the
/// list of encoded candidate actions, asks the model for an index, decodes that action,
/// executes it and then moves the AI state machine to <c>Pausing</c> (or to
/// <c>Finished</c> when no candidate action exists).
/// When <c>ENABLE_TRAIN</c> is defined, the executed step is also recorded through
/// <c>TrainingDataRecorder</c> with a reward equal to the map-score delta.
/// The whole body is compiled only when <c>ENABLE_AIMODEL</c> is defined.
/// </summary>
private void AIModelExecute()
{
#if ENABLE_AIMODEL
    var predictState = TrainingState.Instance.GetMapState(_data.Map, _data.Map.CurPlayer);
    var predictActionsId = TrainingState.Instance.GetAllActionBitCodecForUse(_data.Map, _data.Map.CurPlayer, out var predictActions);
    if (predictActionsId.Count == 0)
    {
        // Nothing left to do for this player.
        AILogicState = AILogicState.Finished;
        return;
    }

    var predictIndex = ModelInference.Instance.Predict(predictState, predictActionsId);

    // Override the model's choice with a Capture action when one is available
    // (the last one found wins).
    // NOTE(review): this assumes predictActions and predictActionsId are index-aligned — confirm.
    for (int i = 0; i < predictActions.Count; i++)
    {
        if (predictActions[i].ActionLogic.ActionId.UnitActionType == UnitActionType.Capture)
        {
            predictIndex = i;
        }
    }

    if (!TrainingState.Instance.GetActionFromBitCodec(predictActionsId[predictIndex], _data.Map, out var actionId, out var param))
    {
        LogSystem.LogError($"反序列化 Action 失败");
        return;
    }

    _data.MaxAiAction = new AIActionBase(param, ActionLogicFactory.GetActionLogic(actionId));

#if ENABLE_TRAIN
    // Snapshot everything needed for the training sample BEFORE the action mutates the map.
    bool isPack = TrainingState.Instance.GetActionBitCodec(_data.MaxAiAction.ActionLogic.ActionId, _data.MaxAiAction.Param, out var packed);
    var curPlayer = _data.MaxAiAction.Param.MapData.CurPlayer;
    var beforeScore = TrainingState.Instance.GetMapScore(_data.MaxAiAction.Param.MapData, curPlayer);
    var state = TrainingState.Instance.GetMapState(_data.MaxAiAction.Param.MapData, curPlayer);
    var validActions = TrainingState.Instance.GetAllActionBitCodec(_data.MaxAiAction.Param.MapData, curPlayer, out _);
    if (isPack)
    {
        // Verify that the encoded chosen action is among the encoded valid actions.
        bool found = false;
        foreach (var action in validActions)
        {
            if (action.Count != packed.Count) continue;

            bool same = true;
            for (int i = 0; i < packed.Count; i++)
            {
                // Compare with an absolute tolerance; a signed difference would treat any
                // candidate whose components are smaller than `packed` as equal.
                if (Mathf.Abs(action[i] - packed[i]) > 0.001f)
                {
                    same = false;
                    break;
                }
            }

            if (same)
            {
                found = true;
                break; // first match is enough
            }
        }

        if (!found)
        {
            // Keep the sample self-consistent: the selected action must be in the valid set.
            validActions.Add(packed);
            LogSystem.LogError($"训练数据出错: {_data.MaxAiAction.ActionLogic.ActionId.GetStringLog()}");
        }
    }
#endif

    _data.MaxAiAction.ActionLogic.CompleteExecute(_data.MaxAiAction.Param);

#if ENABLE_TRAIN
    if (isPack)
    {
        var afterScore = TrainingState.Instance.GetMapScore(_data.MaxAiAction.Param.MapData, curPlayer);
        var reward = afterScore - beforeScore;
        if (_data.MaxAiAction.ActionLogic.ActionId.UnitActionType == UnitActionType.Capture)
        {
            // Flat bonus added on top of the score delta for Capture actions.
            LogSystem.LogError($"占领!!!");
            reward += 10;
        }
        TrainingDataRecorder.Instance.RecordStep(curPlayer.Id, state, validActions.Select(x => x.ToArray()).ToArray(), packed.ToArray(), reward);
    }
#endif

    LogSystem.LogWarning($"{_data.MaxAiAction.ActionLogic.ActionId.GetStringLog()}");
    _data.MaxAiAction.CheckIsActionInPlayerSight();
    _data.MaxAiAction.CheckIsActionDuration();
    // Wait for the action's duration (never negative) before the next AI step.
    _targetTime = Mathf.Max(_data.MaxAiAction.Duration, 0f);
    _isWaitFrame = false;
    _data.MaxAiAction = null;
    AILogicState = AILogicState.Pausing;
#endif
}
public static List<AIRecord> GetCurrentAIRecords()
{
return AIRecordsDict.GetValueOrDefault(CurrentAIPlayerId);

View File

@ -50,7 +50,7 @@ namespace TH1_Logic.AITrain
}
// 记录单步数据到内存
public void RecordStep(uint playerID, float[] state, float[][] validActions, float[] selectedAction, float reward, bool done=false)
public void RecordStep(uint playerID, float[] state, float[][] validActions, float[] selectedAction, float reward)
{
if (!_episodeData.ContainsKey(playerID))
{
@ -69,14 +69,17 @@ namespace TH1_Logic.AITrain
}
// 游戏结束时一次性写入文件
public void SaveEpisode()
public void SaveEpisode(uint playerId, bool isSingle)
{
if (_episodeData.Count == 0) return;
foreach (var kv in _episodeData) kv.Value[^1].Done = true;
string timestamp = System.DateTime.Now.ToString("yyyyMMdd_HHmmss");
string uniqueId = System.Guid.NewGuid().ToString("N").Substring(0, 8);
foreach (var kv in _episodeData)
{
if (kv.Key != playerId) continue;
kv.Value[^1].Done = true;
if (isSingle) kv.Value[^1].Reward += 50;
var fileName = $"episode_{timestamp}_{uniqueId}_{kv.Key}.jsonl";
string filePath = Path.Combine(_outputDir, fileName);
StringBuilder sb = new StringBuilder();

View File

@ -364,7 +364,8 @@ namespace TH1_Logic.AITrain
mapData.GetUnitDataListByPlayerId(playerData.Id, units);
foreach (var unit in mapData.UnitMap.UnitList)
{
var unitScore = unit.Health + unit.GetAttackRange() + unit.GetMoveRange() +
if (!unit.IsAlive()) continue;
var unitScore = unit.Health + unit.GetAttackRange(mapData) + unit.GetMoveRange(mapData) +
unit.GetAllAttackValue(mapData) + unit.GetAllDefenseValue(mapData);
if (units.Contains(unit)) score += unitScore;
else score -= unitScore;

View File

@ -1903,7 +1903,7 @@ namespace Logic.Action
if (!actionParams.MapData.GetGridDataByUnitId(actionParams.UnitData.Id, out var unitGrid)) return false;
if (!actionParams.MapData.GetGridDataByUnitId(actionParams.TargetUnitData.Id, out var targetUnitGrid)) return false;
if (actionParams.MapData.GridMap.CalcDistance(unitGrid, targetUnitGrid) >
actionParams.UnitData.GetAttackRange()) return false;
actionParams.UnitData.GetAttackRange(actionParams.MapData)) return false;
if (!actionParams.UnitData.IsCanAttackAlly()) return false;
if (!actionParams.UnitData.IsCanAttackTargetAlly(actionParams.MapData, actionParams.TargetUnitData))

View File

@ -142,7 +142,7 @@ namespace Logic
var afterScore = TrainingState.Instance.GetMapScore(Main.MapData, curPlayer);
var reward = afterScore - beforeScore;
TrainingDataRecorder.Instance.RecordStep(curPlayer.Id, state,
validActions.Select(x => x.ToArray()).ToArray(), validActions[^1].ToArray(), reward, true);
validActions.Select(x => x.ToArray()).ToArray(), validActions[^1].ToArray(), reward);
#endif
}
}
@ -425,7 +425,9 @@ namespace Logic
var record = Main.MapData.ExportGameRecord();
GameRecordManager.Instance.AddRecord(record);
TrainingDataRecorder.Instance.SaveEpisode();
var playerId = Main.MapData.GetWinPlayer(out var isSingle);
TrainingDataRecorder.Instance.SaveEpisode(playerId, isSingle);
// 添加延迟退出,确保数据保存完成
_gameLogic.Main.StartCoroutine(QuitGameAfterDelay(10f));