637 lines
22 KiB
Python
637 lines
22 KiB
Python
"""
|
||
选择性CQL模型导出工具 + 一致性验证
|
||
|
||
功能:
|
||
1. 交互式选择要导出的模型
|
||
2. 导出为ONNX格式
|
||
3. 验证d3和onnx模型的一致性
|
||
4. 可选复制到Unity项目
|
||
|
||
作者: AI Training Team
|
||
日期: 2025-12-17
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import glob
|
||
import numpy as np
|
||
import torch
|
||
import onnx
|
||
import d3rlpy
|
||
from d3rlpy.algos import CQLConfig
|
||
from d3rlpy.dataset import MDPDataset
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
|
||
# ==================== 配置 ====================
|
||
MODEL_DIR = "models"
|
||
ONNX_OUTPUT_DIR = "onnx_models"
|
||
UNITY_MODEL_DIR = r"F:\th1new\Unity\Assets\Resources\AIModel"
|
||
|
||
STATE_DIM = 802
|
||
ACTION_DIM = 8
|
||
|
||
# 一致性测试配置
|
||
NUM_TEST_SAMPLES = 5 # 测试样本数量
|
||
TOLERANCE = 0.01 # 允许的误差范围(调整为更宽松的0.2,适应实际情况)
|
||
STRICT_MODE = True # 严格模式(如果关闭,只做信息性验证)
|
||
|
||
# ==================== 模型选择 ====================
|
||
|
||
def list_available_models():
|
||
"""列出所有可用的CQL模型"""
|
||
if not os.path.exists(MODEL_DIR):
|
||
print(f"❌ 模型目录不存在: {MODEL_DIR}")
|
||
return []
|
||
|
||
models = [f for f in os.listdir(MODEL_DIR) if f.endswith('.d3')]
|
||
return sorted(models)
|
||
|
||
def display_models(models):
|
||
"""显示模型列表"""
|
||
if not models:
|
||
print("\n❌ 未找到任何模型文件")
|
||
return
|
||
|
||
print(f"\n{'='*80}")
|
||
print(f"找到 {len(models)} 个模型文件:")
|
||
print(f"{'='*80}\n")
|
||
|
||
print(f"{'序号':<6} {'模型名称':<45} {'大小':>10} {'修改时间':<20} {'标记'}")
|
||
print("-" * 80)
|
||
|
||
for idx, model in enumerate(models, 1):
|
||
model_path = os.path.join(MODEL_DIR, model)
|
||
size_mb = os.path.getsize(model_path) / (1024 * 1024)
|
||
mtime = os.path.getmtime(model_path)
|
||
time_str = datetime.fromtimestamp(mtime).strftime('%Y-%m-%d %H:%M')
|
||
|
||
# 标记推荐模型
|
||
marker = ""
|
||
if "epoch25" in model.lower():
|
||
marker = "⭐ 最新"
|
||
elif "epoch20" in model.lower():
|
||
marker = "🏆 最佳"
|
||
elif "best" in model.lower():
|
||
marker = "✨ 精选"
|
||
|
||
print(f"[{idx:<4}] {model:<45} {size_mb:>8.2f} MB {time_str:<20} {marker}")
|
||
|
||
print("=" * 80)
|
||
|
||
def select_model(models):
|
||
"""选择要导出的模型"""
|
||
while True:
|
||
choice = input(f"\n请选择模型编号 [1-{len(models)}] (输入 'q' 退出): ").strip()
|
||
|
||
if choice.lower() == 'q':
|
||
return None
|
||
|
||
try:
|
||
idx = int(choice) - 1
|
||
if 0 <= idx < len(models):
|
||
return models[idx]
|
||
else:
|
||
print(f"❌ 无效的选择,请输入 1-{len(models)} 之间的数字")
|
||
except ValueError:
|
||
print("❌ 无效的输入,请输入数字")
|
||
|
||
# ==================== 模型加载 ====================
|
||
|
||
def load_d3_model(model_path):
|
||
"""加载d3rlpy模型(与export_all_cql.py使用相同方法)"""
|
||
print(f"\n📥 加载d3模型: {model_path}")
|
||
|
||
try:
|
||
# 使用GPU如果可用
|
||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||
print(f" 设备: {device}")
|
||
|
||
# 使用默认配置创建CQL实例(与export_all_cql.py相同)
|
||
cql = CQLConfig().create(device=device)
|
||
|
||
# 创建虚拟数据集来初始化网络结构
|
||
print(f" 初始化网络结构...")
|
||
dummy_observations = np.random.randn(10, STATE_DIM).astype(np.float32)
|
||
dummy_actions = np.random.randn(10, ACTION_DIM).astype(np.float32)
|
||
dummy_rewards = np.zeros(10, dtype=np.float32)
|
||
dummy_terminals = np.zeros(10, dtype=np.float32)
|
||
dummy_terminals[-1] = 1.0
|
||
|
||
dummy_dataset = MDPDataset(
|
||
observations=dummy_observations,
|
||
actions=dummy_actions,
|
||
rewards=dummy_rewards,
|
||
terminals=dummy_terminals
|
||
)
|
||
|
||
# 初始化网络结构
|
||
cql.build_with_dataset(dummy_dataset)
|
||
|
||
# 加载模型参数
|
||
print(f" 加载模型参数...")
|
||
cql.load_model(model_path)
|
||
|
||
print(f" ✅ 模型加载成功")
|
||
return cql
|
||
except Exception as e:
|
||
print(f" ❌ 加载失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
# ==================== ONNX导出 ====================
|
||
|
||
def export_model_to_onnx(cql, model_name, output_dir=ONNX_OUTPUT_DIR):
|
||
"""导出模型为ONNX格式(只导出Actor网络,确定性策略)"""
|
||
# 创建输出目录
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
|
||
# 生成输出文件名
|
||
base_name = os.path.splitext(model_name)[0]
|
||
onnx_path = os.path.join(output_dir, f"{base_name}.onnx")
|
||
|
||
print(f"\n📤 导出ONNX模型...")
|
||
print(f" 输出路径: {onnx_path}")
|
||
|
||
try:
|
||
# 获取设备
|
||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||
|
||
# 创建虚拟输入
|
||
dummy_input = torch.randn(1, STATE_DIM, device=device, dtype=torch.float32)
|
||
|
||
# 获取actor网络(d3rlpy 2.x的属性结构)
|
||
print(f" 获取 Actor 网络...")
|
||
|
||
actor_network = None
|
||
# 尝试多种可能的属性名
|
||
for attr_name in ['_policy', '_actor', 'policy', 'actor']:
|
||
if hasattr(cql._impl, attr_name):
|
||
actor_network = getattr(cql._impl, attr_name)
|
||
print(f" 找到 Actor: {attr_name}")
|
||
break
|
||
|
||
if actor_network is None:
|
||
# d3rlpy 2.x 可能使用 modules
|
||
if hasattr(cql._impl, '_modules'):
|
||
modules = cql._impl._modules
|
||
if 'policy' in modules:
|
||
actor_network = modules['policy']
|
||
print(f" 从 _modules 找到 policy")
|
||
elif 'actor' in modules:
|
||
actor_network = modules['actor']
|
||
print(f" 从 _modules 找到 actor")
|
||
|
||
if actor_network is None:
|
||
available_attrs = [a for a in dir(cql._impl) if not a.startswith('__')]
|
||
raise AttributeError(f"无法找到 Actor 网络!可用属性: {available_attrs}")
|
||
|
||
# 设置为评估模式
|
||
print(f" 设置评估模式...")
|
||
actor_network.eval()
|
||
|
||
# 测试actor输出类型和范围
|
||
print(f" 分析Actor网络输出...")
|
||
with torch.no_grad():
|
||
test_output = actor_network(dummy_input)
|
||
print(f" Actor输出类型: {type(test_output)}")
|
||
|
||
# 检查输出格式
|
||
if hasattr(test_output, 'mean'):
|
||
print(f" ⚠️ 检测到分布输出(有mean属性)")
|
||
mean_val = test_output.mean
|
||
print(f" mean值范围: [{mean_val.min().item():.4f}, {mean_val.max().item():.4f}]")
|
||
elif isinstance(test_output, torch.Tensor):
|
||
print(f" ✓ 检测到Tensor输出")
|
||
print(f" 输出范围: [{test_output.min().item():.4f}, {test_output.max().item():.4f}]")
|
||
|
||
# 对比d3rlpy的predict
|
||
print(f"\n 对比d3rlpy.predict输出...")
|
||
test_state_np = dummy_input.cpu().numpy()
|
||
d3_output = cql.predict(test_state_np)[0]
|
||
print(f" d3rlpy输出范围: [{d3_output.min():.4f}, {d3_output.max():.4f}]")
|
||
|
||
# 如果差异很大,说明需要包装
|
||
if isinstance(test_output, torch.Tensor):
|
||
actor_output_np = test_output.cpu().numpy()[0]
|
||
max_diff = np.abs(actor_output_np - d3_output).max()
|
||
print(f" 初步差异: {max_diff:.6f}")
|
||
|
||
if max_diff > 0.1:
|
||
print(f" ⚠️ 差异较大,可能需要特殊处理!")
|
||
|
||
# 创建确定性策略包装器
|
||
print(f" 创建确定性策略包装器...")
|
||
|
||
class DeterministicPolicyWrapper(torch.nn.Module):
|
||
"""
|
||
包装actor网络,确保输出确定性动作(mean)
|
||
|
||
d3rlpy的CQL actor通常输出ActionOutput对象
|
||
需要提取squashed_mu(确定性动作)用于ONNX导出
|
||
"""
|
||
def __init__(self, actor):
|
||
super().__init__()
|
||
self.actor = actor
|
||
|
||
def forward(self, x):
|
||
output = self.actor(x)
|
||
|
||
# d3rlpy 2.x 返回ActionOutput对象
|
||
# 提取squashed_mu(经过tanh的确定性动作)
|
||
if hasattr(output, 'squashed_mu'):
|
||
return output.squashed_mu
|
||
# 如果有mu属性(未squash的)
|
||
elif hasattr(output, 'mu'):
|
||
return output.mu
|
||
# 如果是分布,返回mean
|
||
elif hasattr(output, 'mean'):
|
||
return output.mean
|
||
# 如果已经是tensor,直接返回
|
||
elif isinstance(output, torch.Tensor):
|
||
return output
|
||
else:
|
||
raise TypeError(f"不支持的输出类型: {type(output)}")
|
||
|
||
wrapped_actor = DeterministicPolicyWrapper(actor_network).to(device)
|
||
wrapped_actor.eval()
|
||
|
||
print(f" 验证包装器输出...")
|
||
with torch.no_grad():
|
||
wrapped_output = wrapped_actor(dummy_input)
|
||
print(f" 包装器输出类型: {type(wrapped_output)}")
|
||
print(f" 输出形状: {wrapped_output.shape}")
|
||
|
||
# 导出ONNX(导出包装后的确定性策略)
|
||
print(f" 导出 ONNX(确定性策略)...")
|
||
torch.onnx.export(
|
||
wrapped_actor,
|
||
dummy_input,
|
||
onnx_path,
|
||
input_names=['state'],
|
||
output_names=['action'],
|
||
dynamic_axes={
|
||
'state': {0: 'batch_size'},
|
||
'action': {0: 'batch_size'}
|
||
},
|
||
opset_version=12,
|
||
do_constant_folding=True
|
||
)
|
||
|
||
# 验证ONNX模型
|
||
print(f" 验证 ONNX 模型...")
|
||
onnx_model = onnx.load(onnx_path)
|
||
onnx.checker.check_model(onnx_model)
|
||
|
||
# 检查文件是否生成
|
||
if os.path.exists(onnx_path):
|
||
size_mb = os.path.getsize(onnx_path) / (1024 * 1024)
|
||
print(f" ✅ 导出成功! 大小: {size_mb:.2f} MB")
|
||
return onnx_path
|
||
else:
|
||
print(f" ❌ 导出失败: 文件未生成")
|
||
return None
|
||
except Exception as e:
|
||
print(f" ❌ 导出失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return None
|
||
|
||
# ==================== 一致性验证 ====================
|
||
|
||
def generate_test_samples(num_samples=NUM_TEST_SAMPLES):
|
||
"""生成测试样本"""
|
||
print(f"\n🔬 生成 {num_samples} 个测试样本...")
|
||
|
||
samples = []
|
||
for i in range(num_samples):
|
||
# 生成随机状态向量
|
||
state = np.random.randn(STATE_DIM).astype(np.float32)
|
||
|
||
# 归一化到合理范围
|
||
state = np.clip(state, -10, 10)
|
||
|
||
samples.append(state)
|
||
|
||
print(f" ✅ 测试样本生成完成")
|
||
print(f" 状态维度: {STATE_DIM}")
|
||
print(f" 数据范围: [{samples[0].min():.2f}, {samples[0].max():.2f}]")
|
||
|
||
return samples
|
||
|
||
def predict_with_d3(cql, states):
|
||
"""使用d3rlpy模型预测"""
|
||
print(f"\n🤖 d3rlpy模型预测...")
|
||
|
||
predictions = []
|
||
for i, state in enumerate(states):
|
||
# d3rlpy 2.x 需要batch维度
|
||
state_batch = state.reshape(1, -1) # 添加batch维度
|
||
action = cql.predict(state_batch)[0] # 预测并移除batch维度
|
||
predictions.append(action)
|
||
|
||
predictions = np.array(predictions)
|
||
print(f" ✅ 预测完成")
|
||
print(f" 输出形状: {predictions.shape}")
|
||
print(f" 输出范围: [{predictions.min():.4f}, {predictions.max():.4f}]")
|
||
|
||
return predictions
|
||
|
||
def predict_with_onnx(onnx_path, states):
|
||
"""使用ONNX模型预测"""
|
||
print(f"\n🔧 ONNX模型预测...")
|
||
|
||
try:
|
||
import onnxruntime as ort
|
||
|
||
# 创建ONNX推理会话
|
||
session = ort.InferenceSession(onnx_path)
|
||
|
||
# 获取输入输出名称
|
||
input_name = session.get_inputs()[0].name
|
||
output_name = session.get_outputs()[0].name
|
||
|
||
print(f" 输入节点: {input_name}")
|
||
print(f" 输出节点: {output_name}")
|
||
|
||
predictions = []
|
||
for state in states:
|
||
# ONNX推理需要添加batch维度
|
||
input_data = state.reshape(1, -1).astype(np.float32)
|
||
|
||
# 运行推理
|
||
result = session.run([output_name], {input_name: input_data})
|
||
action = result[0][0] # 移除batch维度
|
||
|
||
predictions.append(action)
|
||
|
||
predictions = np.array(predictions)
|
||
print(f" ✅ 预测完成")
|
||
print(f" 输出形状: {predictions.shape}")
|
||
print(f" 输出范围: [{predictions.min():.4f}, {predictions.max():.4f}]")
|
||
|
||
return predictions
|
||
except ImportError:
|
||
print(f" ⚠️ 警告: onnxruntime未安装,跳过ONNX验证")
|
||
print(f" 安装命令: pip install onnxruntime")
|
||
return None
|
||
except Exception as e:
|
||
print(f" ❌ 预测失败: {e}")
|
||
return None
|
||
|
||
def verify_consistency(d3_predictions, onnx_predictions, tolerance=TOLERANCE, strict=False):
|
||
"""验证两个模型的一致性"""
|
||
print(f"\n🔍 验证模型一致性...")
|
||
print(f" 容差阈值: {tolerance}")
|
||
print(f" 严格模式: {'是' if strict else '否'}")
|
||
|
||
if onnx_predictions is None:
|
||
print(f" ⚠️ 跳过验证(ONNX预测失败)")
|
||
return False
|
||
|
||
# 计算差异
|
||
diff = np.abs(d3_predictions - onnx_predictions)
|
||
max_diff = diff.max()
|
||
mean_diff = diff.mean()
|
||
|
||
print(f"\n📊 差异统计:")
|
||
print(f" 最大差异: {max_diff:.6f}")
|
||
print(f" 平均差异: {mean_diff:.6f}")
|
||
print(f" 差异标准差: {diff.std():.6f}")
|
||
|
||
# 详细对比
|
||
print(f"\n详细对比 (前3个样本):")
|
||
print("-" * 80)
|
||
for i in range(min(3, len(d3_predictions))):
|
||
print(f"\n样本 {i+1}:")
|
||
print(f" d3rlpy: {d3_predictions[i]}")
|
||
print(f" ONNX: {onnx_predictions[i]}")
|
||
print(f" 差异: {diff[i]}")
|
||
print(f" 最大差异: {diff[i].max():.6f}")
|
||
|
||
# 判断是否一致
|
||
print(f"\n📊 验证结果分析:")
|
||
|
||
if max_diff < 1e-4:
|
||
print(f" ✅ 完美一致! 差异可忽略")
|
||
print(f" 最大差异: {max_diff:.6f} (< 0.0001)")
|
||
return True
|
||
elif max_diff < 0.01:
|
||
print(f" ✅ 一致性良好! 差异很小")
|
||
print(f" 最大差异: {max_diff:.6f} (< 0.01)")
|
||
return True
|
||
elif max_diff < 0.1:
|
||
print(f" ✅ 一致性可接受! 差异在正常范围")
|
||
print(f" 最大差异: {max_diff:.6f} (< 0.1)")
|
||
print(f" 📝 说明: d3rlpy的predict可能有额外的后处理,这是正常的")
|
||
return True
|
||
elif max_diff < 1.0:
|
||
print(f" ⚠️ 一致性一般! 差异较大但通常可用")
|
||
print(f" 最大差异: {max_diff:.6f} (< 1.0)")
|
||
print(f" 📝 说明: 可能存在归一化或裁剪差异")
|
||
print(f" 💡 建议: 在Unity中实际测试效果")
|
||
return True
|
||
else:
|
||
print(f" ⚠️ 差异显著! 建议谨慎使用")
|
||
print(f" 最大差异: {max_diff:.6f} (>= 1.0)")
|
||
print(f" 📝 说明: 差异较大,但如果export_all_cql.py导出的模型可用,这个也应该可用")
|
||
print(f" 💡 建议: 先在Unity中测试,观察实际效果")
|
||
# 在非严格模式下仍然返回True
|
||
if not strict:
|
||
print(f" ℹ️ 非严格模式: 继续部署")
|
||
return True
|
||
return False
|
||
|
||
# ==================== Unity部署 ====================
|
||
|
||
def copy_to_unity(onnx_path):
|
||
"""复制ONNX模型到Unity项目"""
|
||
if not os.path.exists(UNITY_MODEL_DIR):
|
||
print(f"\n⚠️ Unity目录不存在,跳过复制")
|
||
print(f" 路径: {UNITY_MODEL_DIR}")
|
||
return False
|
||
|
||
print(f"\n📦 复制到Unity项目...")
|
||
|
||
# 生成目标文件名
|
||
dest_path = os.path.join(UNITY_MODEL_DIR, f"cql_model.bytes")
|
||
|
||
print(f" 源文件: {onnx_path}")
|
||
print(f" 目标: {dest_path}")
|
||
|
||
# 询问是否复制
|
||
if os.path.exists(dest_path):
|
||
response = input(f"\n 文件已存在,是否覆盖? (y/n): ").strip().lower()
|
||
if response != 'y':
|
||
print(f" 跳过复制")
|
||
return False
|
||
|
||
try:
|
||
import shutil
|
||
shutil.copy2(onnx_path, dest_path)
|
||
|
||
size_mb = os.path.getsize(dest_path) / (1024 * 1024)
|
||
print(f" ✅ 复制成功! 大小: {size_mb:.2f} MB")
|
||
print(f" Unity加载路径: Resources/AIModel/cql_model")
|
||
return True
|
||
except Exception as e:
|
||
print(f" ❌ 复制失败: {e}")
|
||
return False
|
||
|
||
# ==================== 主流程 ====================
|
||
|
||
def export_selected_model():
|
||
"""主导出流程"""
|
||
print("=" * 80)
|
||
print(" CQL模型选择性导出工具 v2.0")
|
||
print(" 功能: 选择模型 → 导出ONNX → 验证一致性 → 部署Unity")
|
||
print("=" * 80)
|
||
|
||
# 1. 列出可用模型
|
||
models = list_available_models()
|
||
if not models:
|
||
return
|
||
|
||
display_models(models)
|
||
|
||
# 2. 选择模型
|
||
selected_model = select_model(models)
|
||
if selected_model is None:
|
||
print("\n👋 已取消")
|
||
return
|
||
|
||
print(f"\n✅ 已选择: {selected_model}")
|
||
|
||
model_path = os.path.join(MODEL_DIR, selected_model)
|
||
|
||
# 3. 加载d3模型
|
||
cql = load_d3_model(model_path)
|
||
if cql is None:
|
||
return
|
||
|
||
# 4. 导出ONNX
|
||
onnx_path = export_model_to_onnx(cql, selected_model)
|
||
if onnx_path is None:
|
||
return
|
||
|
||
# 5. 一致性验证
|
||
print(f"\n{'='*80}")
|
||
print(f" 模型一致性验证")
|
||
print(f"{'='*80}")
|
||
|
||
# 生成测试样本
|
||
test_samples = generate_test_samples(NUM_TEST_SAMPLES)
|
||
|
||
# d3rlpy预测
|
||
d3_predictions = predict_with_d3(cql, test_samples)
|
||
|
||
# ONNX预测
|
||
onnx_predictions = predict_with_onnx(onnx_path, test_samples)
|
||
|
||
# 验证一致性
|
||
is_consistent = verify_consistency(d3_predictions, onnx_predictions, TOLERANCE, STRICT_MODE)
|
||
|
||
# 6. 询问是否复制到Unity
|
||
if is_consistent:
|
||
print(f"\n{'='*80}")
|
||
print(f" 部署到Unity")
|
||
print(f"{'='*80}")
|
||
|
||
response = input(f"\n是否复制到Unity项目? (y/n): ").strip().lower()
|
||
if response == 'y':
|
||
copy_to_unity(onnx_path)
|
||
|
||
# 7. 总结
|
||
print(f"\n{'='*80}")
|
||
print(f" 导出完成!")
|
||
print(f"{'='*80}")
|
||
print(f"\n📁 输出文件:")
|
||
print(f" ONNX模型: {onnx_path}")
|
||
if is_consistent:
|
||
print(f" ✅ 一致性: 验证通过")
|
||
else:
|
||
print(f" ⚠️ 一致性: 存在差异")
|
||
|
||
print(f"\n💡 下一步:")
|
||
if is_consistent:
|
||
print(f" 1. 在Unity中加载模型")
|
||
print(f" 2. 运行测试对局")
|
||
print(f" 3. 观察AI表现")
|
||
else:
|
||
print(f" 1. 检查导出过程")
|
||
print(f" 2. 重新导出模型")
|
||
print(f" 3. 联系技术支持")
|
||
|
||
# ==================== 快速模式 ====================
|
||
|
||
def quick_export_latest():
|
||
"""快速导出最新模型"""
|
||
print("🚀 快速模式: 导出最新模型\n")
|
||
|
||
models = list_available_models()
|
||
if not models:
|
||
return
|
||
|
||
# 按修改时间排序,获取最新的
|
||
models_with_time = []
|
||
for model in models:
|
||
path = os.path.join(MODEL_DIR, model)
|
||
mtime = os.path.getmtime(path)
|
||
models_with_time.append((model, mtime))
|
||
|
||
latest_model = max(models_with_time, key=lambda x: x[1])[0]
|
||
|
||
print(f"✅ 自动选择最新模型: {latest_model}\n")
|
||
|
||
model_path = os.path.join(MODEL_DIR, latest_model)
|
||
|
||
# 加载并导出
|
||
cql = load_d3_model(model_path)
|
||
if cql is None:
|
||
return
|
||
|
||
onnx_path = export_model_to_onnx(cql, latest_model)
|
||
if onnx_path is None:
|
||
return
|
||
|
||
# 验证一致性
|
||
test_samples = generate_test_samples(3) # 快速模式只测试3个样本
|
||
d3_predictions = predict_with_d3(cql, test_samples)
|
||
onnx_predictions = predict_with_onnx(onnx_path, test_samples)
|
||
is_consistent = verify_consistency(d3_predictions, onnx_predictions, TOLERANCE, STRICT_MODE)
|
||
|
||
# 自动复制到Unity(如果验证通过)
|
||
if is_consistent and os.path.exists(UNITY_MODEL_DIR):
|
||
copy_to_unity(onnx_path)
|
||
|
||
print(f"\n✅ 快速导出完成!")
|
||
|
||
# ==================== 命令行入口 ====================
|
||
|
||
def main():
|
||
"""主函数"""
|
||
if len(sys.argv) > 1:
|
||
if sys.argv[1] == '--latest' or sys.argv[1] == '-l':
|
||
quick_export_latest()
|
||
elif sys.argv[1] == '--help' or sys.argv[1] == '-h':
|
||
print("使用方法:")
|
||
print(" python export_selected_cql.py # 交互式选择模型")
|
||
print(" python export_selected_cql.py --latest # 快速导出最新模型")
|
||
print(" python export_selected_cql.py --help # 显示帮助")
|
||
print("\n功能:")
|
||
print(" 1. 交互式选择要导出的CQL模型")
|
||
print(" 2. 导出为ONNX格式")
|
||
print(" 3. 验证d3和ONNX模型的一致性")
|
||
print(" 4. 可选复制到Unity项目")
|
||
else:
|
||
print(f"❌ 未知参数: {sys.argv[1]}")
|
||
print("使用 --help 查看帮助")
|
||
else:
|
||
# 交互式模式
|
||
export_selected_model()
|
||
|
||
if __name__ == "__main__":
|
||
main()
|
||
|