TH1/AITrainPython/d3rlpy/export_selected_cql.py
2025-12-17 16:01:11 +08:00

637 lines
22 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
选择性CQL模型导出工具 + 一致性验证
功能:
1. 交互式选择要导出的模型
2. 导出为ONNX格式
3. 验证d3和onnx模型的一致性
4. 可选复制到Unity项目
作者: AI Training Team
日期: 2025-12-17
"""
import os
import sys
import glob
import numpy as np
import torch
import onnx
import d3rlpy
from d3rlpy.algos import CQLConfig
from d3rlpy.dataset import MDPDataset
from pathlib import Path
from datetime import datetime
# ==================== 配置 ====================
# Directory scanned for trained d3rlpy checkpoints (*.d3); relative paths resolve against the CWD.
MODEL_DIR: str = "models"
# Directory that receives the exported .onnx files; relative to the CWD.
ONNX_OUTPUT_DIR: str = "onnx_models"
# Unity Resources folder; copy_to_unity writes the model there as cql_model.bytes.
UNITY_MODEL_DIR: str = r"F:\th1new\Unity\Assets\Resources\AIModel"
# Observation / action widths used for the dummy build dataset and the ONNX input.
STATE_DIM: int = 802
ACTION_DIM: int = 8

# ---- consistency-check settings ----
NUM_TEST_SAMPLES: int = 5  # number of random states compared between d3rlpy and ONNX
# Allowed absolute difference between d3rlpy and ONNX actions, passed to verify_consistency.
# NOTE(review): an earlier comment talked about loosening this to 0.2, but the value is 0.01 — confirm which is intended.
TOLERANCE: float = 0.01
STRICT_MODE: bool = True  # when turned off, the consistency check is meant to be informational only
# ==================== 模型选择 ====================
def list_available_models():
    """Return the alphabetically sorted ``.d3`` file names found in MODEL_DIR.

    If MODEL_DIR does not exist, an error line is printed and an empty list
    is returned.
    """
    if os.path.exists(MODEL_DIR):
        return sorted(name for name in os.listdir(MODEL_DIR) if name.endswith('.d3'))
    print(f"❌ 模型目录不存在: {MODEL_DIR}")
    return []
def display_models(models):
    """Print a numbered table of model files with size, modification time and a tag.

    Each entry is looked up under MODEL_DIR. The tag is the first match of the
    lower-cased name against "epoch25", "epoch20", then "best"; no match
    leaves the tag empty. An empty ``models`` prints a notice and returns.
    """
    if not models:
        print("\n❌ 未找到任何模型文件")
        return

    rule = '=' * 80
    print(f"\n{rule}")
    print(f"找到 {len(models)} 个模型文件:")
    print(f"{rule}\n")
    print(f"{'序号':<6} {'模型名称':<45} {'大小':>10} {'修改时间':<20} {'标记'}")
    print("-" * 80)

    # Ordered: the first substring hit wins, keeping the epoch25 > epoch20 > best precedence.
    tag_rules = (
        ("epoch25", "⭐ 最新"),
        ("epoch20", "🏆 最佳"),
        ("best", "✨ 精选"),
    )
    for position, name in enumerate(models, start=1):
        path = os.path.join(MODEL_DIR, name)
        megabytes = os.path.getsize(path) / (1024 * 1024)
        stamp = datetime.fromtimestamp(os.path.getmtime(path)).strftime('%Y-%m-%d %H:%M')
        lowered = name.lower()
        tag = next((label for needle, label in tag_rules if needle in lowered), "")
        print(f"[{position:<4}] {name:<45} {megabytes:>8.2f} MB {stamp:<20} {tag}")
    print("=" * 80)
def select_model(models, read_input=input):
    """Prompt until the user picks a model by its 1-based number, or quits.

    Args:
        models: Sequence of model file names, in the order they were displayed.
        read_input: Callable used to read one line of input (defaults to the
            built-in ``input``); injectable for non-interactive use and tests.

    Returns:
        The selected file name, or None if the user entered 'q'/'Q', stdin
        reached EOF, or ``models`` is empty.
    """
    if not models:
        # With no models the prompt range would be "[1-0]" and no answer could
        # ever be valid, so the loop below would never terminate.
        return None

    count = len(models)
    while True:
        try:
            choice = read_input(f"\n请选择模型编号 [1-{count}] (输入 'q' 退出): ").strip()
        except EOFError:
            # stdin closed (Ctrl-D / exhausted pipe): treat as quit instead of crashing.
            return None
        if choice.lower() == 'q':
            return None
        try:
            idx = int(choice) - 1
        except ValueError:
            print("❌ 无效的输入,请输入数字")
            continue
        if 0 <= idx < count:
            return models[idx]
        print(f"❌ 无效的选择,请输入 1-{count} 之间的数字")
# ==================== 模型加载 ====================
def load_d3_model(model_path):
    """Create a default-config CQL, build its networks, then load saved weights.

    The network shapes come from a throw-away 10-step dataset of random
    STATE_DIM observations and ACTION_DIM actions passed to
    ``build_with_dataset``; ``load_model`` then overwrites the parameters.
    The original comments say this mirrors export_all_cql.py.

    NOTE(review): the default CQLConfig must match the training config
    (network sizes, scalers) for load_model to succeed — confirm.

    Returns:
        The CQL object, or None if any step raises (error and traceback printed).
    """
    print(f"\n📥 加载d3模型: {model_path}")
    try:
        # Prefer the first GPU when present.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        print(f" 设备: {device}")
        cql = CQLConfig().create(device=device)

        print(f" 初始化网络结构...")
        n_steps = 10
        terminals = np.zeros(n_steps, dtype=np.float32)
        terminals[-1] = 1.0  # mark the last step terminal so the steps form one complete episode
        placeholder = MDPDataset(
            observations=np.random.randn(n_steps, STATE_DIM).astype(np.float32),
            actions=np.random.randn(n_steps, ACTION_DIM).astype(np.float32),
            rewards=np.zeros(n_steps, dtype=np.float32),
            terminals=terminals,
        )
        cql.build_with_dataset(placeholder)

        print(f" 加载模型参数...")
        cql.load_model(model_path)
        print(f" ✅ 模型加载成功")
        return cql
    except Exception as err:
        print(f" ❌ 加载失败: {err}")
        import traceback
        traceback.print_exc()
        return None
# ==================== ONNX导出 ====================
class _DeterministicPolicyWrapper(torch.nn.Module):
    """Expose a d3rlpy actor as a module whose ``forward`` returns a plain action tensor.

    torch.onnx.export needs tensor outputs. The d3rlpy 2.x CQL actor is
    expected to return an object carrying ``squashed_mu`` (the tanh-squashed
    deterministic action); other output forms are handled as fallbacks.
    """

    def __init__(self, actor):
        super().__init__()
        self.actor = actor

    def forward(self, x):
        output = self.actor(x)
        # Test for a raw tensor FIRST: torch.Tensor has a ``mean`` *method*, so a
        # ``hasattr(output, 'mean')`` check placed earlier would return the bound
        # method instead of the tensor and break the export.
        if isinstance(output, torch.Tensor):
            return output
        if hasattr(output, 'squashed_mu'):
            return output.squashed_mu
        if hasattr(output, 'mu'):
            # Un-squashed mean; presumably differs from predict() if the policy applies tanh.
            return output.mu
        if hasattr(output, 'mean'):
            # e.g. a torch.distributions.Distribution, whose ``mean`` is a tensor property.
            return output.mean
        raise TypeError(f"不支持的输出类型: {type(output)}")


def _find_actor_network(impl):
    """Locate the actor/policy network on a d3rlpy algorithm implementation object.

    Direct attributes are tried first, then a ``_modules`` container, which may
    be a dict or an attribute container (e.g. a dataclass). A dict membership
    test (``'policy' in modules``) would raise TypeError on the latter.

    Raises:
        AttributeError: if no candidate is found; the message lists impl's attributes.
    """
    for attr_name in ('_policy', '_actor', 'policy', 'actor'):
        candidate = getattr(impl, attr_name, None)
        if candidate is not None:
            print(f" 找到 Actor: {attr_name}")
            return candidate

    modules = getattr(impl, '_modules', None)
    if modules is not None:
        for key in ('policy', 'actor'):
            if isinstance(modules, dict):
                candidate = modules.get(key)
            else:
                candidate = getattr(modules, key, None)
            if candidate is not None:
                print(f" 从 _modules 找到 {key}")
                return candidate

    available_attrs = [a for a in dir(impl) if not a.startswith('__')]
    raise AttributeError(f"无法找到 Actor 网络!可用属性: {available_attrs}")


def _module_device(module, fallback):
    """Return the device of ``module``'s first parameter, or ``torch.device(fallback)``."""
    try:
        return next(module.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device(fallback)


def export_model_to_onnx(cql, model_name, output_dir=ONNX_OUTPUT_DIR):
    """Export the CQL actor as a deterministic-policy ONNX model.

    Only the actor is exported, wrapped by _DeterministicPolicyWrapper so the
    graph outputs a plain action tensor. Before exporting, the wrapped output
    for one random state is compared with ``cql.predict`` to surface
    mismatches early. The file is checked with ``onnx.checker`` afterwards.

    Args:
        cql: a loaded d3rlpy CQL algorithm whose ``_impl`` has been built.
        model_name: checkpoint file name; its stem names the ``.onnx`` file.
        output_dir: destination directory (created if missing).

    Returns:
        Path of the written ``.onnx`` file, or None on any failure (the error
        and traceback are printed).
    """
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.splitext(model_name)[0]
    onnx_path = os.path.join(output_dir, f"{base_name}.onnx")

    print(f"\n📤 导出ONNX模型...")
    print(f" 输出路径: {onnx_path}")

    try:
        print(f" 获取 Actor 网络...")
        actor_network = _find_actor_network(cql._impl)

        # Put the dummy input where the actor's weights live, so a CPU/GPU
        # mismatch cannot break the trace.
        fallback_device = "cuda:0" if torch.cuda.is_available() else "cpu"
        device = _module_device(actor_network, fallback_device)
        dummy_input = torch.randn(1, STATE_DIM, device=device, dtype=torch.float32)

        print(f" 设置评估模式...")
        actor_network.eval()

        print(f" 分析Actor网络输出...")
        with torch.no_grad():
            test_output = actor_network(dummy_input)
        print(f" Actor输出类型: {type(test_output)}")
        if isinstance(test_output, torch.Tensor):
            print(f" ✓ 检测到Tensor输出")
            print(f" 输出范围: [{test_output.min().item():.4f}, {test_output.max().item():.4f}]")
        elif hasattr(test_output, 'mean'):
            print(f" ⚠️ 检测到分布输出有mean属性")
            mean_val = test_output.mean
            if isinstance(mean_val, torch.Tensor):
                print(f" mean值范围: [{mean_val.min().item():.4f}, {mean_val.max().item():.4f}]")

        print(f" 创建确定性策略包装器...")
        wrapped_actor = _DeterministicPolicyWrapper(actor_network).to(device)
        wrapped_actor.eval()

        print(f" 验证包装器输出...")
        with torch.no_grad():
            wrapped_output = wrapped_actor(dummy_input)
        print(f" 包装器输出类型: {type(wrapped_output)}")
        print(f" 输出形状: {wrapped_output.shape}")

        # Compare the exact tensor that will be exported against d3rlpy's predict().
        print(f"\n 对比d3rlpy.predict输出...")
        d3_output = np.asarray(cql.predict(dummy_input.cpu().numpy())[0])
        print(f" d3rlpy输出范围: [{d3_output.min():.4f}, {d3_output.max():.4f}]")
        wrapped_np = wrapped_output.detach().cpu().numpy()[0]
        if wrapped_np.shape == d3_output.shape:
            max_diff = np.abs(wrapped_np - d3_output).max()
            print(f" 初步差异: {max_diff:.6f}")
            if max_diff > 0.1:
                print(f" ⚠️ 差异较大,可能需要特殊处理!")
        else:
            print(f" ⚠️ 形状不一致: 包装器 {wrapped_np.shape} vs d3rlpy {d3_output.shape}")

        print(f" 导出 ONNX确定性策略...")
        torch.onnx.export(
            wrapped_actor,
            dummy_input,
            onnx_path,
            input_names=['state'],
            output_names=['action'],
            dynamic_axes={
                'state': {0: 'batch_size'},
                'action': {0: 'batch_size'},
            },
            opset_version=12,
            do_constant_folding=True,
        )

        if not os.path.exists(onnx_path):
            print(f" ❌ 导出失败: 文件未生成")
            return None

        print(f" 验证 ONNX 模型...")
        onnx.checker.check_model(onnx.load(onnx_path))

        size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
        print(f" ✅ 导出成功! 大小: {size_mb:.2f} MB")
        return onnx_path
    except Exception as e:
        print(f" ❌ 导出失败: {e}")
        import traceback
        traceback.print_exc()
        return None
# ==================== 一致性验证 ====================
def generate_test_samples(num_samples=NUM_TEST_SAMPLES):
    """Generate random float32 state vectors for the d3rlpy/ONNX comparison.

    Each state has STATE_DIM standard-normal entries clipped to [-10, 10].
    Clipping only bounds outliers; nothing is normalized.

    Args:
        num_samples: number of states to create. Values <= 0 yield an empty list.

    Returns:
        List of 1-D float32 numpy arrays of length STATE_DIM.
    """
    print(f"\n🔬 生成 {num_samples} 个测试样本...")
    count = max(int(num_samples), 0)
    # One (count, STATE_DIM) draw replaces the per-sample loop.
    batch = np.clip(np.random.randn(count, STATE_DIM).astype(np.float32), -10, 10)
    samples = list(batch)

    print(f" ✅ 测试样本生成完成")
    print(f" 状态维度: {STATE_DIM}")
    if samples:
        # Guarded: samples[0] would raise IndexError when num_samples is 0.
        print(f" 数据范围: [{samples[0].min():.2f}, {samples[0].max():.2f}]")
    return samples
def predict_with_d3(cql, states):
    """Run ``cql.predict`` on each state individually and stack the results.

    Args:
        cql: object exposing ``predict(batch)`` that returns one row per batch row.
        states: iterable of 1-D observation arrays.

    Returns:
        numpy array with one row of actions per input state.
    """
    print(f"\n🤖 d3rlpy模型预测...")
    # predict() wants a leading batch axis: feed each state as a (1, dim) batch, keep row 0.
    stacked = np.array([cql.predict(s.reshape(1, -1))[0] for s in states])
    print(f" ✅ 预测完成")
    print(f" 输出形状: {stacked.shape}")
    print(f" 输出范围: [{stacked.min():.4f}, {stacked.max():.4f}]")
    return stacked
def predict_with_onnx(onnx_path, states):
    """Run the exported ONNX model on each state with ONNX Runtime (CPU provider).

    Args:
        onnx_path: path to the ``.onnx`` file.
        states: iterable of 1-D observation arrays; each is fed as a (1, dim)
            float32 batch to the model's first input.

    Returns:
        numpy array with one row of actions per state, or None if onnxruntime
        is not installed or inference fails (the error is printed).
    """
    print(f"\n🔧 ONNX模型预测...")
    try:
        import onnxruntime as ort
    except ImportError:
        print(f" ⚠️ 警告: onnxruntime未安装跳过ONNX验证")
        print(f" 安装命令: pip install onnxruntime")
        return None

    try:
        # Explicit providers: since onnxruntime 1.9, builds that ship extra
        # providers (e.g. onnxruntime-gpu) raise ValueError when none are given.
        # The CPU provider is always available and is enough for a parity check.
        session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name
        print(f" 输入节点: {input_name}")
        print(f" 输出节点: {output_name}")

        predictions = []
        for state in states:
            input_data = state.reshape(1, -1).astype(np.float32)  # add batch axis
            result = session.run([output_name], {input_name: input_data})
            predictions.append(result[0][0])  # drop batch axis
        predictions = np.array(predictions)

        print(f" ✅ 预测完成")
        print(f" 输出形状: {predictions.shape}")
        print(f" 输出范围: [{predictions.min():.4f}, {predictions.max():.4f}]")
        return predictions
    except Exception as e:
        print(f" ❌ 预测失败: {e}")
        import traceback
        traceback.print_exc()
        return None
def verify_consistency(d3_predictions, onnx_predictions, tolerance=None, strict=False):
    """Compare d3rlpy and ONNX action predictions and decide whether they agree.

    Prints summary statistics, the first three samples side by side and a
    graded assessment of the maximum absolute difference.

    Args:
        d3_predictions: array-like of actions from the d3rlpy model.
        onnx_predictions: array-like of actions from ONNX Runtime, or None if
            ONNX inference was skipped or failed.
        tolerance: maximum allowed absolute difference. None (default) reads
            the module-level TOLERANCE at call time.
        strict: when True, the check fails if the maximum difference exceeds
            ``tolerance``. When False the check is informational and any
            comparable result passes.

    Returns:
        True if the predictions are considered consistent, else False. Always
        False when ``onnx_predictions`` is None, the arrays are empty, or
        their shapes differ.
    """
    if tolerance is None:
        tolerance = TOLERANCE

    print(f"\n🔍 验证模型一致性...")
    print(f" 容差阈值: {tolerance}")
    # Both branches were empty strings before, so the mode was never shown.
    print(f" 严格模式: {'是' if strict else '否'}")

    if onnx_predictions is None:
        print(f" ⚠️ 跳过验证ONNX预测失败")
        return False

    d3_arr = np.asarray(d3_predictions)
    onnx_arr = np.asarray(onnx_predictions)
    # Without this check, mismatched shapes either raise a broadcast error or
    # silently broadcast into a meaningless diff.
    if d3_arr.shape != onnx_arr.shape:
        print(f" ❌ 输出形状不一致: d3rlpy {d3_arr.shape} vs ONNX {onnx_arr.shape}")
        return False
    if d3_arr.size == 0:
        print(f" ⚠️ 没有可比较的预测结果")
        return False

    diff = np.abs(d3_arr - onnx_arr)
    max_diff = diff.max()
    mean_diff = diff.mean()

    print(f"\n📊 差异统计:")
    print(f" 最大差异: {max_diff:.6f}")
    print(f" 平均差异: {mean_diff:.6f}")
    print(f" 差异标准差: {diff.std():.6f}")

    print(f"\n详细对比 (前3个样本):")
    print("-" * 80)
    for i in range(min(3, len(d3_arr))):
        print(f"\n样本 {i+1}:")
        print(f" d3rlpy: {d3_arr[i]}")
        print(f" ONNX: {onnx_arr[i]}")
        print(f" 差异: {diff[i]}")
        print(f" 最大差异: {diff[i].max():.6f}")

    # Graded, human-readable assessment (informational only).
    print(f"\n📊 验证结果分析:")
    if max_diff < 1e-4:
        print(f" ✅ 完美一致! 差异可忽略")
        print(f" 最大差异: {max_diff:.6f} (< 0.0001)")
    elif max_diff < 0.01:
        print(f" ✅ 一致性良好! 差异很小")
        print(f" 最大差异: {max_diff:.6f} (< 0.01)")
    elif max_diff < 0.1:
        print(f" ✅ 一致性可接受! 差异在正常范围")
        print(f" 最大差异: {max_diff:.6f} (< 0.1)")
        print(f" 📝 说明: d3rlpy的predict可能有额外的后处理这是正常的")
    elif max_diff < 1.0:
        print(f" ⚠️ 一致性一般! 差异较大但通常可用")
        print(f" 最大差异: {max_diff:.6f} (< 1.0)")
        print(f" 📝 说明: 可能存在归一化或裁剪差异")
        print(f" 💡 建议: 在Unity中实际测试效果")
    else:
        print(f" ⚠️ 差异显著! 建议谨慎使用")
        print(f" 最大差异: {max_diff:.6f} (>= 1.0)")
        print(f" 📝 说明: 差异较大但如果export_all_cql.py导出的模型可用这个也应该可用")
        print(f" 💡 建议: 先在Unity中测试观察实际效果")

    # Pass/fail decision: the tolerance is the actual gate; strict mode enforces it.
    if max_diff <= tolerance:
        return True
    if strict:
        print(f" ❌ 严格模式: 最大差异 {max_diff:.6f} 超出容差 {tolerance}")
        return False
    print(f" 非严格模式: 继续部署")
    return True
# ==================== Unity部署 ====================
def copy_to_unity(onnx_path):
    """Copy an exported ONNX file into UNITY_MODEL_DIR as ``cql_model.bytes``.

    If the target file already exists, the user is asked (y/n) before it is
    overwritten.

    Returns:
        True after a successful copy. False if the Unity directory is missing,
        the user declines to overwrite, or the copy raises.
    """
    unity_dir = UNITY_MODEL_DIR
    if not os.path.exists(unity_dir):
        print(f"\n⚠️ Unity目录不存在跳过复制")
        print(f" 路径: {unity_dir}")
        return False

    print(f"\n📦 复制到Unity项目...")
    target = os.path.join(unity_dir, "cql_model.bytes")
    print(f" 源文件: {onnx_path}")
    print(f" 目标: {target}")

    # Only prompt when an existing model would be clobbered.
    if os.path.exists(target):
        answer = input(f"\n 文件已存在,是否覆盖? (y/n): ").strip().lower()
        if answer != 'y':
            print(f" 跳过复制")
            return False

    try:
        import shutil
        shutil.copy2(onnx_path, target)
        copied_mb = os.path.getsize(target) / (1024 * 1024)
        print(f" ✅ 复制成功! 大小: {copied_mb:.2f} MB")
        print(f" Unity加载路径: Resources/AIModel/cql_model")
        return True
    except Exception as err:
        print(f" ❌ 复制失败: {err}")
        return False
# ==================== 主流程 ====================
def _print_section(title):
    """Print ``title`` between two 80-character '=' rules, preceded by a blank line."""
    rule = '=' * 80
    print(f"\n{rule}")
    print(title)
    print(rule)


def export_selected_model():
    """Interactive pipeline: list → choose → load → export ONNX → verify → optional Unity copy → summary.

    Returns early (None) when no models exist, the user cancels, loading
    fails, or the export fails.
    """
    print("=" * 80)
    print(" CQL模型选择性导出工具 v2.0")
    print(" 功能: 选择模型 → 导出ONNX → 验证一致性 → 部署Unity")
    print("=" * 80)

    models = list_available_models()
    if not models:
        return
    display_models(models)

    selected_model = select_model(models)
    if selected_model is None:
        print("\n👋 已取消")
        return
    print(f"\n✅ 已选择: {selected_model}")

    cql = load_d3_model(os.path.join(MODEL_DIR, selected_model))
    if cql is None:
        return
    onnx_path = export_model_to_onnx(cql, selected_model)
    if onnx_path is None:
        return

    # Consistency check between the d3rlpy policy and the exported ONNX graph.
    _print_section(" 模型一致性验证")
    test_samples = generate_test_samples(NUM_TEST_SAMPLES)
    d3_predictions = predict_with_d3(cql, test_samples)
    onnx_predictions = predict_with_onnx(onnx_path, test_samples)
    is_consistent = verify_consistency(d3_predictions, onnx_predictions, TOLERANCE, STRICT_MODE)

    # Deployment is only offered when the check passed.
    if is_consistent:
        _print_section(" 部署到Unity")
        if input(f"\n是否复制到Unity项目? (y/n): ").strip().lower() == 'y':
            copy_to_unity(onnx_path)

    _print_section(" 导出完成!")
    print(f"\n📁 输出文件:")
    print(f" ONNX模型: {onnx_path}")
    print(" ✅ 一致性: 验证通过" if is_consistent else " ⚠️ 一致性: 存在差异")

    print(f"\n💡 下一步:")
    if is_consistent:
        next_steps = (" 1. 在Unity中加载模型", " 2. 运行测试对局", " 3. 观察AI表现")
    else:
        next_steps = (" 1. 检查导出过程", " 2. 重新导出模型", " 3. 联系技术支持")
    for step in next_steps:
        print(step)
# ==================== 快速模式 ====================
def quick_export_latest():
    """Non-interactive export of the most recently modified ``.d3`` in MODEL_DIR.

    Verification uses 3 random samples. If the check passes and
    UNITY_MODEL_DIR exists, the model is copied to Unity automatically.
    """
    print("🚀 快速模式: 导出最新模型\n")
    models = list_available_models()
    if not models:
        return

    # max() returns the first of several equally recent files, scanning in list order.
    latest_model = max(models, key=lambda name: os.path.getmtime(os.path.join(MODEL_DIR, name)))
    print(f"✅ 自动选择最新模型: {latest_model}\n")

    cql = load_d3_model(os.path.join(MODEL_DIR, latest_model))
    if cql is None:
        return
    onnx_path = export_model_to_onnx(cql, latest_model)
    if onnx_path is None:
        return

    samples = generate_test_samples(3)  # quick mode: only 3 samples
    consistent = verify_consistency(
        predict_with_d3(cql, samples),
        predict_with_onnx(onnx_path, samples),
        TOLERANCE,
        STRICT_MODE,
    )

    if consistent and os.path.exists(UNITY_MODEL_DIR):
        copy_to_unity(onnx_path)
    print(f"\n✅ 快速导出完成!")
# ==================== 命令行入口 ====================
def main():
    """Command-line dispatcher.

    No arguments runs the interactive flow. ``--latest``/``-l`` runs the
    quick export and ``--help``/``-h`` prints usage. Any other first argument
    prints an error hint. Arguments after the first are ignored.
    """
    if len(sys.argv) <= 1:
        # Interactive mode.
        export_selected_model()
        return

    option = sys.argv[1]
    if option in ('--latest', '-l'):
        quick_export_latest()
    elif option in ('--help', '-h'):
        usage_lines = (
            "使用方法:",
            " python export_selected_cql.py # 交互式选择模型",
            " python export_selected_cql.py --latest # 快速导出最新模型",
            " python export_selected_cql.py --help # 显示帮助",
            "\n功能:",
            " 1. 交互式选择要导出的CQL模型",
            " 2. 导出为ONNX格式",
            " 3. 验证d3和ONNX模型的一致性",
            " 4. 可选复制到Unity项目",
        )
        for line in usage_lines:
            print(line)
    else:
        print(f"❌ 未知参数: {option}")
        print("使用 --help 查看帮助")
# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()