"""Print the input and output signatures of an ONNX model.

Run as a script, optionally passing the path of the ``.onnx`` file as the
first command-line argument; otherwise ``model_path`` below is used.
"""
import sys

import onnxruntime as ort

# Default model inspected when no path is given on the command line.
model_path = 'models/ppo_model_epoch_20.onnx'


def _print_nodes(nodes):
    """Print one line per graph node: its name, shape and element type.

    ``nodes`` is what ``InferenceSession.get_inputs()`` / ``get_outputs()``
    return; each item exposes ``name``, ``shape`` and ``type`` attributes.
    """
    for node in nodes:
        print(f" {node.name}: {node.shape} ({node.type})")


def main(argv=None):
    """Load the model and print its inputs and outputs.

    Args:
        argv: Command-line arguments without the program name. If non-empty,
            the first element is used as the model path instead of
            ``model_path``. Defaults to ``sys.argv[1:]``.
    """
    args = sys.argv[1:] if argv is None else argv
    path = args[0] if args else model_path

    # Only graph metadata is read, so the CPU provider is sufficient. Passing
    # it explicitly also avoids the ValueError that onnxruntime >= 1.9 raises
    # when a multi-provider build (e.g. onnxruntime-gpu) is instantiated
    # without a ``providers`` argument.
    session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])

    print("模型输入:")  # "Model inputs:"
    _print_nodes(session.get_inputs())

    print("\n模型输出:")  # "Model outputs:"
    _print_nodes(session.get_outputs())


if __name__ == "__main__":
    main()