2025-12-18 17:41:05 +08:00

789 lines
30 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
d3rlpy 训练脚本 - 适配 Unity TrainingDataRecorder 格式
数据格式:
{
"State": [802个float],
"Actions": [{"data": [8个float]}, {"data": [8个float]}, ...], # 合法动作列表训练时不用使用AllActions包装
"SelectedAction": [8个float], # 专家选择的动作(归一化后)
"Reward": float,
"Done": bool # 注意:此字段会被忽略
}
Episode 划分策略:
- 每个 .jsonl 文件 = 一个完整的 Episode整局游戏
- 文件内所有样本的 Done 字段会被忽略
- 只有文件的最后一个样本会被标记为 Done=True
- 这符合4X游戏的特性只有整局游戏结束才是真正的终止状态
"""
import os
import sys
import json
import glob
import numpy as np
import d3rlpy
from d3rlpy.dataset import MDPDataset
from d3rlpy.algos import BC, IQL, CQL
# ==================== Configuration ====================
# Training switches: set to True for every algorithm that should be trained.
TRAIN_BC = False  # Behavior Cloning: simplest and fastest; recommended for testing the data filter
TRAIN_IQL = False  # Implicit Q-Learning: most stable
TRAIN_CQL = True  # Conservative Q-Learning: most powerful

# Paths. Both can be overridden through environment variables so the script can
# be pointed at another dataset (e.g. r"F:\TrainData\Test") without editing it.
DATA_DIR = os.environ.get("D3RLPY_DATA_DIR", r"F:\TrainData\TrainingData")  # directory of *.jsonl episode files
MODEL_DIR = os.environ.get("D3RLPY_MODEL_DIR", "./models")  # where trained models are written

# Training hyperparameters
STATE_DIM = 802
ACTION_DIM = 8  # SelectedAction is 8 normalized floats
BATCH_SIZE = 256
LEARNING_RATE_BC = 1e-3
LEARNING_RATE_IQL = 1e-4
LEARNING_RATE_CQL = 3e-5  # 5th adjustment: lowered again (5e-5 still diverged after epoch 28)
LEARNING_RATE_CQL_CRITIC = 2e-5  # critic lowered as well, keeping its ratio to the actor rate
LEARNING_RATE_CQL_ALPHA = 5e-6  # very slow alpha decay (1e-5 was still too fast)
N_EPOCHS = 25  # a 100-epoch run peaked at epoch 20, so 25 epochs is enough

# Early stopping (used by CQL training)
ENABLE_EARLY_STOPPING = True
EARLY_STOP_PATIENCE = 3  # number of degraded epochs tolerated before stopping
EARLY_STOP_DELTA = 20.0  # Actor Loss rising more than this above the best counts as degradation
EARLY_STOP_TEMP_THRESHOLD = 2.5  # stop immediately once Temperature exceeds this
EARLY_STOP_ACTOR_POSITIVE = True  # stop immediately once Actor Loss turns positive

# Optional: scale the number of BC training steps with the dataset size, so
# large datasets are fully used.
AUTO_ADJUST_STEPS = False

# GPU: cuda:0 is used when this is True and CUDA is available, otherwise CPU.
USE_GPU = True

# ==================== Data filtering ====================
ENABLE_FILTERING = False  # whether episodes are filtered by quality
MIN_EPISODE_REWARD = 300.0  # episodes whose total reward is below this are dropped entirely
MIN_EPISODE_STEPS = 1  # episodes with fewer valid steps than this are dropped
# ==================== 数据加载 ====================
def load_training_data():
    """Load every episode from DATA_DIR, optionally filtering whole episodes.

    Each ``*.jsonl`` file in DATA_DIR is one complete episode. The per-sample
    ``Done`` field is ignored; only the last valid sample of a file is marked
    terminal. A line is skipped on its own (without discarding the rest of the
    file) when it is blank, is not valid JSON, lacks a field, holds a
    null/non-numeric value, or has a State / SelectedAction whose shape is not
    (STATE_DIM,) / (ACTION_DIM,).

    When ENABLE_FILTERING is False, loading is delegated to ``_load_all_data``.
    Otherwise an episode is kept only if its total reward is
    >= MIN_EPISODE_REWARD and it has >= MIN_EPISODE_STEPS valid samples.

    Returns:
        (observations, actions, rewards, terminals), all float32 numpy arrays.

    Raises:
        FileNotFoundError: DATA_DIR contains no .jsonl files.
        ValueError: no episode could be read, or none passed the filter.
    """
    print("=" * 60)
    print("加载训练数据")
    print("=" * 60)
    print(f"数据目录: {DATA_DIR}")
    if ENABLE_FILTERING:
        print(f"\n✅ 数据过滤已启用:")
        print(f" - 最低 Episode 奖励: {MIN_EPISODE_REWARD}")
        print(f" - 最少 Episode 步数: {MIN_EPISODE_STEPS}")
        print(f" - 策略: 以 Episode 为单位过滤,保留高质量完整轨迹")
    # Sorted so sample order (and therefore training) is reproducible; raw glob
    # order depends on the filesystem.
    jsonl_files = sorted(glob.glob(os.path.join(DATA_DIR, "*.jsonl")))
    if not jsonl_files:
        raise FileNotFoundError(f"未在 {DATA_DIR} 中找到 .jsonl 文件")
    print(f"\n找到 {len(jsonl_files)} 个数据文件")
    if not ENABLE_FILTERING:
        # No filtering: load every episode as-is.
        return _load_all_data(jsonl_files)

    # ---- Single pass: read each episode, keep it only if it qualifies ----
    print("\n开始扫描:统计 Episode 质量并加载高质量数据...")
    observations = []
    actions = []
    rewards = []
    terminals = []
    all_episode_rewards = []  # total reward of every episode that was read
    all_episode_steps = []  # valid-sample count of every episode that was read
    total_raw_samples = 0  # valid samples in successfully read files
    filtered_samples = 0
    total_episodes = 0
    good_episodes = 0
    for file_idx, file_path in enumerate(jsonl_files):
        samples = []  # (state, action, reward) for each valid line of this file
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        continue
                    try:
                        data = json.loads(line)
                        state = np.array(data['State'], dtype=np.float32)
                        selected_action = np.array(data['SelectedAction'], dtype=np.float32)
                        reward = float(data['Reward'])
                    except (ValueError, KeyError, TypeError):
                        # ValueError includes json.JSONDecodeError and non-numeric
                        # values; TypeError covers null or mistyped fields. Skip
                        # only this line instead of losing the whole episode.
                        continue
                    if state.shape != (STATE_DIM,) or selected_action.shape != (ACTION_DIM,):
                        continue
                    # The per-sample Done field is deliberately ignored.
                    samples.append((state, selected_action, reward))
        except Exception as e:
            # Best effort: report an unreadable file and move on.
            print(f" 警告: 文件 {os.path.basename(file_path)} 读取失败: {e}")
            continue
        if samples:
            # End of file == end of episode.
            episode_steps = len(samples)
            episode_reward = sum(r for _, _, r in samples)
            total_raw_samples += episode_steps
            total_episodes += 1
            all_episode_rewards.append(episode_reward)
            all_episode_steps.append(episode_steps)
            if episode_reward >= MIN_EPISODE_REWARD and episode_steps >= MIN_EPISODE_STEPS:
                last_idx = episode_steps - 1
                for idx, (s, a, r) in enumerate(samples):
                    observations.append(s)
                    actions.append(a)
                    rewards.append(r)
                    # Only the final sample of the file terminates the episode.
                    terminals.append(1.0 if idx == last_idx else 0.0)
                filtered_samples += episode_steps
                good_episodes += 1
        if (file_idx + 1) % 100 == 0:
            print(f" 已处理 {file_idx + 1}/{len(jsonl_files)} 个文件...")
    if total_episodes == 0:
        raise ValueError("未找到任何完整的 Episode")
    if filtered_samples == 0:
        raise ValueError(f"没有 Episode 满足条件(奖励>={MIN_EPISODE_REWARD} 且 步数>={MIN_EPISODE_STEPS}),请降低阈值")
    # Episode-level statistics
    print(f"\nEpisode 质量统计:")
    print(f" 总 Episode 数: {total_episodes}")
    print(f" 符合条件的 Episode: {good_episodes} ({good_episodes/total_episodes*100:.1f}%)")
    print(f"\n 奖励分布:")
    print(f" 范围: [{min(all_episode_rewards):.2f}, {max(all_episode_rewards):.2f}]")
    print(f" 平均: {np.mean(all_episode_rewards):.2f}")
    print(f" 中位数: {np.median(all_episode_rewards):.2f}")
    print(f" 过滤阈值: >={MIN_EPISODE_REWARD}")
    print(f"\n 步数分布:")
    print(f" 范围: [{min(all_episode_steps)}, {max(all_episode_steps)}]")
    print(f" 平均: {np.mean(all_episode_steps):.1f}")
    print(f" 中位数: {int(np.median(all_episode_steps))}")
    print(f" 过滤阈值: >={MIN_EPISODE_STEPS}")
    observations = np.array(observations, dtype=np.float32)
    actions = np.array(actions, dtype=np.float32)
    rewards = np.array(rewards, dtype=np.float32)
    terminals = np.array(terminals, dtype=np.float32)
    print(f"\n过滤完成:")
    print(f" 原始样本数: {total_raw_samples}")
    print(f" 保留样本数: {filtered_samples}")
    print(f" 过滤比例: {(1 - filtered_samples/total_raw_samples)*100:.1f}%")
    print(f" 状态维度: {observations.shape}")
    print(f" 动作维度: {actions.shape}")
    print(f" 奖励范围: [{rewards.min():.2f}, {rewards.max():.2f}]")
    print(f" Episode数: {int(terminals.sum())}")
    # Reward distribution
    print(f"\nReward 统计:")
    print(f" 平均 Reward: {rewards.mean():.4f}")
    print(f" Reward 标准差: {rewards.std():.4f}")
    nonzero_rewards = rewards[rewards != 0]
    nonzero_ratio = nonzero_rewards.size / rewards.size * 100
    print(f" 非零 Reward 比例: {nonzero_ratio:.2f}%")
    # Guard against an empty selection: mean() of an empty array is NaN + warning.
    if 0 < nonzero_rewards.size < rewards.size:
        print(f" 非零 Reward 平均值: {nonzero_rewards.mean():.2f}")
    # Actions are expected to be normalized to roughly [0, 1].
    if actions.max() > 1.1 or actions.min() < -0.1:
        print(f"\n⚠️ 警告: 动作值超出 [0,1] 范围")
        print(f" 最小值: {actions.min()}")
        print(f" 最大值: {actions.max()}")
    return observations, actions, rewards, terminals
def _load_all_data(jsonl_files):
"""不过滤,加载所有数据(每个文件=1个Episode"""
observations = []
actions = []
rewards = []
terminals = []
total_samples = 0
total_episodes = 0
for file_idx, file_path in enumerate(jsonl_files):
file_samples = 0
file_data = [] # 缓存当前文件的所有样本
try:
with open(file_path, 'r', encoding='utf-8') as f:
for line_num, line in enumerate(f, 1):
if not line.strip():
continue
try:
data = json.loads(line)
state = np.array(data['State'], dtype=np.float32)
if len(state) != STATE_DIM:
continue
selected_action = np.array(data['SelectedAction'], dtype=np.float32)
if len(selected_action) != ACTION_DIM:
continue
reward = float(data['Reward'])
# 忽略样本中的Done字段统一在文件末尾设置
file_data.append((state, selected_action, reward))
file_samples += 1
except (json.JSONDecodeError, KeyError) as e:
continue
# 文件读取完毕,将所有样本添加到数据集
if len(file_data) > 0:
for idx, (state, action, reward) in enumerate(file_data):
observations.append(state)
actions.append(action)
rewards.append(reward)
# 只有最后一个样本标记为Done=True
terminals.append(1.0 if idx == len(file_data) - 1 else 0.0)
total_episodes += 1
except Exception as e:
print(f"错误: 读取文件 {os.path.basename(file_path)} 失败: {e}")
continue
total_samples += file_samples
print(f" [{file_idx+1}/{len(jsonl_files)}] {os.path.basename(file_path)}: {file_samples} 样本")
if total_samples == 0:
raise ValueError("没有加载到任何有效样本")
observations = np.array(observations, dtype=np.float32)
actions = np.array(actions, dtype=np.float32)
rewards = np.array(rewards, dtype=np.float32)
terminals = np.array(terminals, dtype=np.float32)
print(f"\n加载完成:")
print(f" 总样本数: {total_samples}")
print(f" 总文件数(Episode数): {total_episodes}")
print(f" 状态维度: {observations.shape}")
print(f" 动作维度: {actions.shape}")
print(f" 奖励范围: [{rewards.min():.2f}, {rewards.max():.2f}]")
print(f" Episode终止点数: {int(terminals.sum())}")
# Reward 分布统计
print(f"\nReward 统计:")
print(f" 平均 Reward: {rewards.mean():.4f}")
print(f" Reward 标准差: {rewards.std():.4f}")
nonzero_ratio = (rewards != 0).sum() / len(rewards) * 100
print(f" 非零 Reward 比例: {nonzero_ratio:.2f}%")
if nonzero_ratio < 100:
nonzero_rewards = rewards[rewards != 0]
print(f" 非零 Reward 平均值: {nonzero_rewards.mean():.2f}")
# Reward 范围检查和警告
max_abs_reward = max(abs(rewards.min()), abs(rewards.max()))
print(f"\nReward 范围诊断:")
print(f" 最大绝对值: {max_abs_reward:.2f}")
if max_abs_reward > 100000:
print(f" ❌ 严重警告: Reward 范围过大(> 100000")
print(f" 建议: 在 C# 端缩放 reward *= {1000/max_abs_reward:.6f}")
print(f" 风险: 训练可能不稳定或失败")
elif max_abs_reward > 10000:
print(f" ⚠️ 警告: Reward 范围较大(> 10000")
print(f" 建议: 考虑缩放 reward *= {1000/max_abs_reward:.6f}")
print(f" 说明: CQL 可以处理,但缩放会更稳定")
elif max_abs_reward > 5000:
print(f" 提示: Reward 范围中等(> 5000")
print(f" 说明: CQL 可以处理,训练过程请关注 loss 曲线")
else:
print(f" ✅ Reward 范围合理(<= 5000")
print(f" 说明: 可以直接训练CQL 会自适应")
return observations, actions, rewards, terminals
# ==================== 训练函数 ====================
def train_bc(dataset):
    """Train a Behavior Cloning (BC) policy and save it to MODEL_DIR/bc_model.d3.

    Does nothing and returns None when TRAIN_BC is False. Training runs for
    N_EPOCHS * 1000 steps (1000 steps per logged epoch). With AUTO_ADJUST_STEPS
    enabled, the step count is raised to ten passes over the dataset when that
    is larger.

    Args:
        dataset: the d3rlpy dataset built by ``main``.

    Returns:
        The trained BC algorithm object, or None if BC training is disabled.
    """
    if not TRAIN_BC:
        return None
    print("\n" + "=" * 60)
    print("训练 BC (Behavior Cloning)")
    print("=" * 60)
    # d3rlpy 2.x configures algorithms through BCConfig
    from d3rlpy.algos import BCConfig
    import torch
    device = "cuda:0" if USE_GPU and torch.cuda.is_available() else "cpu"
    print(f"使用设备: {device}")
    if device.startswith("cuda"):
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    config = BCConfig(
        batch_size=BATCH_SIZE,
        learning_rate=LEARNING_RATE_BC,
    )
    bc = config.create(device=device)
    base_steps = N_EPOCHS * 1000
    if AUTO_ADJUST_STEPS:
        # Use the number of transitions (samples), not len(): d3rlpy 2.x replay
        # buffers report samples via ``transition_count``. len() is kept only as a
        # fallback for dataset objects without that attribute.
        dataset_size = getattr(dataset, "transition_count", None)
        if dataset_size is None:
            dataset_size = len(dataset)
        # Steps needed for one real pass over the data
        steps_per_real_epoch = max(1, dataset_size // BATCH_SIZE)
        # Aim for every sample being seen at least 10 times
        recommended_steps = steps_per_real_epoch * 10
        actual_steps = max(base_steps, recommended_steps)
        print(f"\n数据集自适应训练:")
        print(f" 数据集大小: {dataset_size} 样本")
        print(f" 1个真实Epoch: {steps_per_real_epoch}")
        print(f" 原始训练步数: {base_steps}")
        print(f" 建议训练步数: {recommended_steps} (保证每个样本至少被看10次)")
        print(f" 实际训练步数: {actual_steps}")
        if actual_steps > base_steps:
            print(f" ✅ 已自动增加训练步数以充分利用数据")
    else:
        actual_steps = base_steps
    print(f"\n训练参数:")
    print(f" 总步数: {actual_steps}")
    print(f" Batch size: {BATCH_SIZE}")
    print(f" 预计时间: {actual_steps / 1000 * 7 / 60:.1f} 分钟")
    print("\n开始训练...")
    bc.fit(
        dataset,
        n_steps=actual_steps,
        n_steps_per_epoch=1000,
        show_progress=True,
    )
    model_path = os.path.join(MODEL_DIR, "bc_model.d3")
    bc.save_model(model_path)
    print(f"✓ BC 模型已保存: {model_path}")
    print(f" 如需导出 ONNX请运行: python export_onnx.py")
    return bc
def train_iql(dataset):
    """Train an Implicit Q-Learning (IQL) agent and save it to MODEL_DIR/iql_model.d3.

    Returns None immediately unless TRAIN_IQL is enabled. Actor and critic share
    LEARNING_RATE_IQL, and training runs N_EPOCHS * 1000 steps in 1000-step epochs.

    Args:
        dataset: the d3rlpy dataset built by ``main``.

    Returns:
        The trained IQL algorithm object, or None when IQL training is disabled.
    """
    if not TRAIN_IQL:
        return None

    separator = "=" * 60
    print("\n" + separator)
    print("训练 IQL (Implicit Q-Learning)")
    print(separator)

    from d3rlpy.algos import IQLConfig
    import torch

    use_cuda = USE_GPU and torch.cuda.is_available()
    device = "cuda:0" if use_cuda else "cpu"
    print(f"使用设备: {device}")
    if use_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    shared_lr = LEARNING_RATE_IQL
    iql = IQLConfig(
        batch_size=BATCH_SIZE,
        actor_learning_rate=shared_lr,
        critic_learning_rate=shared_lr,
    ).create(device=device)

    total_steps = N_EPOCHS * 1000
    print(f"参数: steps={total_steps}, batch_size={BATCH_SIZE}")
    print("开始训练...")
    iql.fit(dataset, n_steps=total_steps, n_steps_per_epoch=1000, show_progress=True)

    save_path = os.path.join(MODEL_DIR, "iql_model.d3")
    iql.save_model(save_path)
    print(f"✓ IQL 模型已保存: {save_path}")
    print(" 如需导出 ONNX请运行: python export_onnx.py")
    return iql
class EarlyStoppingCallback:
    """Per-epoch early-stopping decision for CQL training.

    ``check`` applies these rules in order and returns True to stop:
      0. Actor Loss or Temperature is NaN/Inf -> stop immediately.
      1. Actor Loss > 0 -> stop immediately (only if ``actor_positive_stop``).
      2. Temperature > ``temp_threshold`` -> stop immediately.
      3. Actor Loss more than ``delta`` above the best (lowest) value seen,
         for ``patience`` consecutive checks -> stop.

    After a stop the caller can read ``best_epoch`` / ``best_actor_loss``
    (``best_epoch`` stays 0 if no epoch was ever recorded as best) and
    ``stop_reason``.
    """

    def __init__(self, patience=3, delta=20.0, temp_threshold=2.5, actor_positive_stop=True):
        """
        Args:
            patience: number of consecutive degraded checks tolerated.
            delta: Actor Loss must exceed best + delta to count as degraded.
            temp_threshold: Temperature above this stops immediately.
            actor_positive_stop: whether a positive Actor Loss stops immediately.
        """
        self.patience = patience
        self.delta = delta
        self.temp_threshold = temp_threshold
        self.actor_positive_stop = actor_positive_stop
        # Lower Actor Loss is better, so start from +inf. Starting from -inf
        # made every epoch look "degraded" and no best was ever recorded.
        self.best_actor_loss = float('inf')
        self.counter = 0  # consecutive degraded checks
        self.best_epoch = 0  # 0 means no best recorded yet
        self.stop_reason = None

    def _stop(self, epoch, reason, icon):
        """Record *reason*, print the stop banner and return True."""
        self.stop_reason = reason
        print(f"\n{icon} Early Stop at Epoch {epoch}: {reason}")
        print(f" 最佳模型在 Epoch {self.best_epoch} (Actor Loss = {self.best_actor_loss:.2f})")
        return True

    def check(self, epoch, actor_loss, temp):
        """Update the tracking state with one epoch's metrics.

        Args:
            epoch: 1-based epoch number being reported.
            actor_loss: mean Actor Loss of that epoch.
            temp: mean SAC temperature of that epoch.

        Returns:
            bool: True if training should stop.
        """
        # Rule 0: divergence often shows up as NaN/Inf, which no comparison catches.
        if not (np.isfinite(actor_loss) and np.isfinite(temp)):
            return self._stop(epoch, f"指标出现 NaN/Inf (Actor Loss = {actor_loss}, Temp = {temp})", "🚨")
        # Rule 1: a positive Actor Loss stops immediately.
        if self.actor_positive_stop and actor_loss > 0:
            return self._stop(epoch, f"Actor Loss转正 ({actor_loss:.2f})", "🚨")
        # Rule 2: a runaway Temperature stops immediately.
        if temp > self.temp_threshold:
            return self._stop(epoch, f"Temperature失控 ({temp:.2f} > {self.temp_threshold})", "🚨")
        # Rule 3: repeated degradation relative to the best Actor Loss.
        if actor_loss > self.best_actor_loss + self.delta:
            self.counter += 1
            print(f"⚠️ Epoch {epoch}: Actor Loss恶化 ({actor_loss:.2f} vs 最佳 {self.best_actor_loss:.2f}), 计数 {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                return self._stop(epoch, f"Actor Loss连续 {self.patience} Epochs恶化", "⚠️")
            return False
        if actor_loss < self.best_actor_loss:
            self.best_actor_loss = actor_loss
            self.best_epoch = epoch
            print(f"✅ Epoch {epoch}: 新最佳 Actor Loss = {actor_loss:.2f} (Temp = {temp:.2f})")
        # Within tolerance of the best: the degradation streak is broken.
        self.counter = 0
        return False
def _last_epoch_metrics(results):
    """Return the metrics dict of the last epoch reported by ``fit``.

    d3rlpy 2.x ``fit`` returns a list of ``(epoch, metrics_dict)`` tuples; a
    plain dict is also accepted. Anything else yields an empty dict.
    """
    if isinstance(results, dict):
        return results
    if isinstance(results, (list, tuple)) and results:
        last = results[-1]
        if isinstance(last, tuple) and len(last) == 2 and isinstance(last[1], dict):
            return last[1]
    return {}


def train_cql(dataset):
    """Train a Conservative Q-Learning (CQL) agent with per-epoch checkpoints.

    Returns None immediately unless TRAIN_CQL is enabled. Otherwise ``fit`` is
    called once per 1000-step epoch (up to N_EPOCHS times). After each epoch the
    model is saved as ``MODEL_DIR/cql_model_epoch{N}.d3`` and, when
    ENABLE_EARLY_STOPPING is set, the epoch's ``actor_loss``/``temp`` metrics are
    fed to EarlyStoppingCallback. The final model ``MODEL_DIR/cql_model.d3`` is
    the best checkpoint if training stopped early after a best epoch was
    recorded, and the current model otherwise.

    Args:
        dataset: the d3rlpy dataset built by ``main``.

    Returns:
        The CQL algorithm object in its last trained state (not necessarily the
        best epoch), or None when CQL training is disabled.
    """
    if not TRAIN_CQL:
        return None
    print("\n" + "=" * 60)
    print("训练 CQL (Conservative Q-Learning)")
    print("=" * 60)
    from d3rlpy.algos import CQLConfig
    from d3rlpy.logging import FileAdapterFactory
    import datetime
    import shutil
    import torch
    device = "cuda:0" if USE_GPU and torch.cuda.is_available() else "cpu"
    print(f"使用设备: {device}")
    if device.startswith("cuda"):
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    # Only learning rates are tuned here; the original author noted that d3rlpy
    # 2.x offers no direct gradient-clipping setting.
    config = CQLConfig(
        batch_size=BATCH_SIZE,
        actor_learning_rate=LEARNING_RATE_CQL,
        critic_learning_rate=LEARNING_RATE_CQL_CRITIC,
        alpha_learning_rate=LEARNING_RATE_CQL_ALPHA,
        temp_learning_rate=1e-4,  # left at the value the original author called the default
    )
    cql = config.create(device=device)
    print(f"\n🔧 CQL 配置(基于 100 Epochs 完整分析优化):")
    print(f" Actor Learning Rate: {LEARNING_RATE_CQL} (3e-5已验证最优)")
    print(f" Critic Learning Rate: {LEARNING_RATE_CQL_CRITIC} (2e-5保持比例)")
    print(f" Alpha Learning Rate: {LEARNING_RATE_CQL_ALPHA} (5e-6极慢衰减)")
    print(f" 训练步数: {N_EPOCHS * 1000} ({N_EPOCHS} Epochs)")
    print(f" Batch Size: {BATCH_SIZE}")
    print(f" 早停机制: {'✅ 已启用' if ENABLE_EARLY_STOPPING else '❌ 未启用'}")
    print("\n📊 100 Epochs 训练关键发现:")
    print(" ✅ Epoch 20: Actor Loss = -124.3 (历史最佳!)")
    print(" ⚠️ Epoch 25-36: Actor Loss 从 -115 → -15 (过拟合)")
    print(" 💥 Epoch 37: Actor Loss = +1.1 (转正爆炸)")
    print(" 💥 Epoch 100: Actor Loss = +1620 (完全失控)")
    print("\n🎯 优化策略:")
    print(" 1. ✅ 保持学习率 3e-5 (已验证能达到 -124.3)")
    print(" 2. ✅ 训练轮数 100 → 25 (Epoch 20 是黄金点)")
    print(" 3. ✅ 启用早停机制 (监控 Actor Loss 和 Temperature)")
    print(" 4. ✅ Temperature > 2.5 立即停止 (Epoch 37 时 Temp=2.51)")
    print(" 5. ✅ Actor Loss 转正立即停止 (防止爆炸)")
    print("\n📈 学习率调优完整历史:")
    print(" 尝试1: 3e-4 → Epoch 10爆炸 (Actor +2000)")
    print(" 尝试2: 1e-4 → Epoch 20发散 (Actor +5.7)")
    print(" 尝试3: 1.5e-4 → Epoch 28爆炸 (Actor +228)")
    print(" 尝试4: 5e-5 → Epoch 28发散 (Actor +49)")
    print(" 尝试5: 3e-5 → Epoch 20最优 (Actor -124.3 ⭐)")
    print(" 续 → Epoch 37爆炸 (Actor +1620)")
    print("\n💡 核心结论:")
    print(" - 学习率 3e-5 是最优配置")
    print(" - 数据集只支持 20-25 Epochs")
    print(" - 必须使用早停机制防止过拟合")
    print(" - 预期本次在 Epoch 20-25 自动停止")
    print("\n开始训练...")
    early_stop = None
    if ENABLE_EARLY_STOPPING:
        early_stop = EarlyStoppingCallback(
            patience=EARLY_STOP_PATIENCE,
            delta=EARLY_STOP_DELTA,
            temp_threshold=EARLY_STOP_TEMP_THRESHOLD,
            actor_positive_stop=EARLY_STOP_ACTOR_POSITIVE,
        )
        print(f"\n✅ 早停机制已启用:")
        print(f" - 容忍恶化: {EARLY_STOP_PATIENCE} epochs")
        print(f" - 恶化阈值: Actor Loss 增长 > {EARLY_STOP_DELTA}")
        print(f" - Temperature 警戒线: {EARLY_STOP_TEMP_THRESHOLD}")
        print(f" - Actor Loss 转正立即停止: {EARLY_STOP_ACTOR_POSITIVE}")
        print(f"\n基于 100 Epochs 训练分析:")
        print(f" - Epoch 20 是黄金停止点 (Actor Loss = -124.3)")
        print(f" - Epoch 28+ 开始发散")
        print(f" - 预期本次训练在 Epoch 20-25 自动停止\n")
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    log_dir = f"d3rlpy_logs/CQL_{timestamp}"
    stopped_early = False
    epoch = 0
    # Manual epoch loop so a checkpoint can be saved and early stopping evaluated
    # between epochs; it relies on repeated fit() calls training the same `cql`.
    for epoch in range(1, N_EPOCHS + 1):
        results = cql.fit(
            dataset,
            n_steps=1000,
            n_steps_per_epoch=1000,
            show_progress=True,
            logger_adapter=FileAdapterFactory(root_dir=log_dir),
        )
        checkpoint_path = os.path.join(MODEL_DIR, f"cql_model_epoch{epoch}.d3")
        cql.save_model(checkpoint_path)
        if early_stop is None:
            continue
        metrics = _last_epoch_metrics(results)
        actor_loss = metrics.get('actor_loss')
        if actor_loss is None:
            # Metric not reported by fit(): early stopping can't be evaluated.
            continue
        temp = metrics.get('temp', 0.0)
        if early_stop.check(epoch, float(actor_loss), float(temp)):
            stopped_early = True
            print(f"\n✅ 训练提前停止在 Epoch {epoch}")
            print(f"✅ 最佳模型: Epoch {early_stop.best_epoch}")
            print(f"✅ 停止原因: {early_stop.stop_reason}")
            break
    final_model_path = os.path.join(MODEL_DIR, "cql_model.d3")
    if stopped_early and early_stop.best_epoch > 0:
        # Roll back to the best epoch's checkpoint (saved inside the loop).
        best_model_path = os.path.join(MODEL_DIR, f"cql_model_epoch{early_stop.best_epoch}.d3")
        shutil.copy(best_model_path, final_model_path)
        print(f"\n✓ 最佳模型已保存为: {final_model_path}")
        print(f" (来自 Epoch {early_stop.best_epoch})")
    else:
        if stopped_early:
            # Stopped before any epoch was recorded as best (e.g. the first
            # epoch already diverged), so there is no checkpoint to roll back to.
            print("⚠️ 未记录到最佳 Epoch,保存当前模型")
        cql.save_model(final_model_path)
        print(f"✓ CQL 模型已保存: {final_model_path}")
    print(f" 训练日志: {log_dir}")
    print(f" 如需导出 ONNX请运行: python export_onnx.py")
    if stopped_early:
        print(f"\n📊 训练总结:")
        print(f" 计划训练: {N_EPOCHS} epochs")
        print(f" 实际训练: {epoch} epochs (提前停止)")
        print(f" 最佳 epoch: {early_stop.best_epoch}")
        print(f" 节省时间: {(N_EPOCHS - epoch) * 26 / 60:.1f} 分钟")
    return cql
# ==================== 主函数 ====================
def main():
    """Script entry point.

    Prints the banner and configuration, loads the episode data, wraps it in a
    d3rlpy ``MDPDataset`` and trains every algorithm whose TRAIN_* switch is on,
    in the order BC, IQL, CQL. Returns early, after printing why, when the data
    directory does not exist, data loading raises, or no algorithm is enabled.
    """
    banner = """
╔════════════════════════════════════════════════════════════════╗
║ d3rlpy 训练脚本 - Unity 数据格式 ║
╚════════════════════════════════════════════════════════════════╝
数据格式:
State: 802 维 float32状态
SelectedAction: 8 维 float32归一化后的动作
Reward: float32奖励
Done: bool是否结束
训练配置:
"""
    print(banner)

    # (label, enabled switch, trainer) for each supported algorithm
    trainers = (
        ("BC", TRAIN_BC, train_bc),
        ("IQL", TRAIN_IQL, train_iql),
        ("CQL", TRAIN_CQL, train_cql),
    )
    for label, enabled, _ in trainers:
        status = '✓ 开启' if enabled else '✗ 关闭'
        print(f" {label}: {status}")
    print()

    if not ENABLE_FILTERING:
        print(f"数据过滤: ❌ 未启用(使用全部数据)")
    else:
        print(f"数据过滤: ✅ 已启用")
        print(f" - Episode 最低奖励: {MIN_EPISODE_REWARD}")
        print(f" - Episode 最少步数: {MIN_EPISODE_STEPS}")
    print()

    print(f"d3rlpy 版本: {d3rlpy.__version__}")
    import torch
    has_gpu = torch.cuda.is_available()
    print(f"GPU 可用: {has_gpu}")
    if has_gpu:
        print(f"GPU 设备: {torch.cuda.get_device_name(0)}")
    print()

    if not os.path.exists(DATA_DIR):
        print(f"错误: 数据目录不存在: {DATA_DIR}")
        return
    os.makedirs(MODEL_DIR, exist_ok=True)

    try:
        observations, actions, rewards, terminals = load_training_data()
    except Exception as err:
        print(f"\n错误: 数据加载失败")
        print(f"原因: {err}")
        return

    rule = "=" * 60
    print("\n" + rule)
    print("创建 d3rlpy 数据集")
    print(rule)
    dataset = MDPDataset(
        observations=observations,
        actions=actions,
        rewards=rewards,
        terminals=terminals,
    )
    print(f"✓ 数据集创建成功")
    print(f" Action type: Continuous ({ACTION_DIM}D)")
    print(f" Total steps: {len(observations)}")

    if not any(enabled for _, enabled, _ in trainers):
        print("\n⚠️ 所有算法都未开启,请修改脚本顶部的开关")
        return

    trained_models = []
    for label, enabled, trainer in trainers:
        if enabled and trainer(dataset):
            trained_models.append(label)

    print("\n" + rule)
    print("训练完成!")
    print(rule)
    print(f"已训练: {', '.join(trained_models)}")
    print(f"模型保存位置: {os.path.abspath(MODEL_DIR)}")
    print()
    print("查看训练日志:")
    print(" 1. 可视化工具: python view_training_logs.py")
    print(" 2. 原始日志: d3rlpy_logs/ 目录")
    print()
    print("下一步:")
    for step_hint in (
        " 1. 查看训练曲线评估模型质量",
        " 2. 导出 ONNX: python export_onnx.py",
        " 3. 在 Unity 中使用 ONNX Runtime 加载模型",
    ):
        print(step_hint)
    print()


if __name__ == "__main__":
    main()