之前用torchvision里面的官方模型做了脚本的测试,今天手贱,没忍住,用onnx官方的models做了验证,于是带来了新的血泪。
onnx/modelsgist更新了脚本,增加了验证代码的文件onnx2pytorch_validate.py:
https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8
脚本用血泪更新后,基本起码可以支持以下onnx/models中的几个模型了:
重建模型后误差较小的特征:
满足下面两点的模型误差会显著减少:
当然,还有一个更大的可能,我某个转换代码有问题。
误差测试方法:
使用MXNet导入onnx VS 脚本导入模型,输出以下内容:
print(abs_error.max(), abs_error.mean(), abs_error.min())
print(mo[0][:5]) # mxnet raw output
print(o[0][:5]) # depsmodule raw output
运行结果:
desenet121的结果没有flatten,太恶心了,不要怪我,转出来的模型就这样。
================================================================================
googlenet.onnx :
Input Blob Names: ['data_0']
<Symbol softmax0>
0.038564444 0.00010830386 1.1318343e-07
[0.00047492 0.00078218 0.00057791 0.00075334 0.00317502]
[0.0004738 0.00075904 0.0004524 0.00064938 0.0024273 ]
================================================================================
resnet18v2.onnx :
Input Blob Names: ['data']
<Symbol broadcast_add1>
4.053116e-06 8.2214177e-07 0.0
[-2.7655013 0.78636014 0.33238387 0.44985968 0.8458633 ]
[-2.7655027 0.7863597 0.33238354 0.44986013 0.8458635 ]
================================================================================
resnet34v2.onnx :
Input Blob Names: ['data']
<Symbol broadcast_add2>
4.2915344e-06 1.0785647e-06 0.0
[-2.2723439 0.04653598 0.30353007 0.2595243 0.41742972]
[-2.2723436 0.04653586 0.30352995 0.2595226 0.41742748]
================================================================================
squeezenet1.1.onnx :
Input Blob Names: ['data']
<Symbol squeezenet0_flatten0_reshape0>
5.2452087e-06 7.539566e-07 0.0
[0.5420265 4.364538 4.056094 4.889414 4.368317 ]
[0.5420265 4.364538 4.056096 4.889414 4.368317 ]
================================================================================
mobilenetv2-1.0.onnx :
Input Blob Names: ['data']
<Symbol mobilenetv20_output_flatten0_reshape0>
7.390976e-06 1.855201e-06 0.0
[-3.0427516 1.1981574 0.29107922 -0.53182554 1.0944995 ]
[-3.04275 1.198158 0.29107827 -0.5318265 1.0945004 ]
================================================================================
alex_net.onnx :
Input Blob Names: ['data_0']
<Symbol softmax1>
0.00015634485 9.566847e-06 7.887138e-09
[0.00150446 0.00085357 0.00193772 0.00269155 0.00280163]
[0.0015188 0.0008455 0.00191552 0.00266171 0.0027587 ]
================================================================================
densenet121.onnx :
Input Blob Names: ['data_0']
<Symbol convolution318>
5.00679e-06 8.735047e-07 0.0
[[[-1.742769 ]]
[[-0.05456648]]
[[ 1.0543407 ]]
[[ 0.522862 ]]
[[ 0.82946616]]]
[[[-1.7427675 ]]
[[-0.05456671]]
[[ 1.054343 ]]
[[ 0.5228648 ]]
[[ 0.82946706]]]
================================================================================
vgg16.onnx :
Input Blob Names: ['data']
<Symbol broadcast_add129>
1.1920929e-06 2.18953e-07 0.0
[-1.8850654 1.1514534 -0.01803813 0.7195613 -0.06146571]
[-1.8850652 1.1514536 -0.01803814 0.7195611 -0.06146575]
之前的记录链接: