-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Open
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
Description
Expected behavior
TVM should run the model correctly.
Actual behavior
The following simple model can be executed by both onnxruntime and ONNX's ReferenceEvaluator; the results are as follows:
[[[[-14.942039 25.242477 12.100965 ... 57.370975
43.9531 2.7949858 ]
[ 10.37292 -10.740159 -38.21392 ... -25.30576
20.154476 -19.093735 ]
[ 23.041002 41.820953 -34.108677 ... 13.982165
7.2081375 11.132856 ]
...
[ -5.590187 -30.208424 9.572759 ... 32.32894
10.533618 -10.12186 ]
[-11.91229 29.118868 -20.662777 ... -16.910463
-3.0499296 -8.469831 ]
[ -1.340055 -17.950693 1.7485574 ... 44.479458
11.766148 -29.151772 ]]
[[ 3.5306716 -17.110134 -12.246135 ... 3.6663613
10.713657 18.570526 ]
[ 14.867607 -7.739101 -32.685593 ... -8.97093
-14.347578 32.1913 ]
[ 7.462762 -14.961873 -18.268976 ... -16.221258
9.145517 22.871971 ]
...
[ 58.016407 -28.310326 27.333632 ... 33.982685
-0.22384354 44.509457 ]
[-27.212149 -17.831676 7.140829 ... 10.431597
33.523167 -36.84701 ]
[ 4.2589498 -18.459444 9.673733 ... -21.87571
11.560403 -11.319146 ]]
[[ 1.6221294 -24.357258 37.832687 ... -15.675017
17.112104 -19.541176 ]
[ -1.9535408 -43.395298 -7.589774 ... 32.80785
39.018574 -9.826303 ]
[ 23.802937 -22.675396 2.9990005 ... -29.012583
55.184933 16.124914 ]
...
[ 13.447808 1.2474155 68.31904 ... -33.191654
7.361887 -14.480252 ]
[ 8.724002 -44.378754 -3.259966 ... -43.776226
47.07514 4.846405 ]
[-28.980995 0.75377727 -20.248524 ... -13.6773405
34.299976 -5.997253 ]]
...
[[-33.108776 -16.460379 7.451367 ... -3.0716655
8.792357 9.199628 ]
[-10.372536 -22.135944 -21.878239 ... 55.7764
-7.191108 -14.206541 ]
[-53.47607 10.408365 -20.441448 ... 23.109488
1.2017591 6.3908386 ]
...
[ -0.9444683 -29.720179 -36.938797 ... -32.23233
-35.5812 42.359478 ]
[-28.41974 -0.1681884 -38.875988 ... 14.968706
-23.108547 -32.480797 ]
[ 2.9518757 1.1772063 -21.35058 ... -24.29589
1.791322 -2.575415 ]]
[[ -1.1625404 -6.98105 -10.189097 ... -26.975689
-14.790653 -8.432305 ]
[ 17.315321 34.95646 -33.681236 ... -8.27439
9.273292 -20.079172 ]
[ -9.413344 25.616909 -8.439555 ... -17.31526
18.547462 -18.068047 ]
...
[ 0.11532002 -46.406166 -29.583313 ... -17.339268
17.583694 38.199493 ]
[-49.912685 10.065314 -67.41777 ... -17.945366
-17.426779 24.320415 ]
[ -9.248135 4.745124 -5.5740147 ... 33.417686
-3.3946886 4.773201 ]]
[[-11.940619 1.2533135 -15.480752 ... 17.695894
23.379114 -6.3436594 ]
[ -2.8251894 -1.4625757 12.186805 ... -25.161877
6.0388117 -8.681617 ]
[ 37.58026 7.1358566 -53.616978 ... 49.06318
-9.981028 15.518544 ]
...
[-19.931892 23.587257 -68.27191 ... -25.605322
-31.568523 0.6990657 ]
[ 56.762306 -4.785701 3.260337 ... 2.4171884
-50.66727 -23.977222 ]
[ 13.220358 -13.136884 18.563856 ... -9.229172
-32.588108 27.032162 ]]]]
However, the onnx frontend of TVM cannot import it:
File "/home/ubuntu/Documents/test1.py", line 49, in test
tvm_model = from_onnx(onnx_model, opset=14, keep_params_in_input=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py", line 4260, in from_onnx
return g.from_onnx(graph, opset)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py", line 3890, in from_onnx
self._construct_nodes(graph)
File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py", line 4071, in _construct_nodes
raise err
File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py", line 4068, in _construct_nodes
op = self.bb.normalize(op)
^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/block_builder.py", line 672, in normalize
return _ffi_api.BlockBuilderNormalize(self, expr) # type: ignore
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "python/tvm_ffi/cython/function.pxi", line 904, in tvm_ffi.core.Function.__call__
File "<unknown>", line 0, in tvm::relax::Normalizer::Normalize(tvm::RelaxExpr const&)
File "<unknown>", line 0, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
File "<unknown>", line 0, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#9}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
File "<unknown>", line 0, in tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
File "<unknown>", line 0, in tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
File "<unknown>", line 0, in tvm::relax::InferStructInfoBroadcastArith(tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
File "<unknown>", line 0, in tvm::relax::StructInfo tvm::relax::InferStructInfoBroadcast<tvm::runtime::DataType (*)(tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::relax::StructInfo const&, tvm::relax::StructInfo const&)>(tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::runtime::DataType (*)(tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::relax::StructInfo const&, tvm::relax::StructInfo const&))
File "<unknown>", line 0, in tvm::relax::InferBinaryBroadcastShape(tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::ffi::Array<tvm::PrimExpr, void> const&, tvm::ffi::Array<tvm::PrimExpr, void> const&)
File "<unknown>", line 0, in tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
tvm.error.InternalError: In Op(relax.add), the first input shape at dim 1 is T.int64(16) and the second input shape at dim 1 is T.int64(32), which are not broadcastable.
[11:03:35] /home/ubuntu/Documents/DLCompilers/tvm/src/relax/ir/block_builder.cc:64: Warning: BlockBuilder destroyed with remaining blocks!
Environment
OS: Ubuntu 20.04
TVM: 0.23.dev0 (f4e28d3)
onnxruntime: 1.23.2
Steps to reproduce
This bug can be reproduced by the following code with the model in the attachment.
from typing import Dict, List, Literal, Optional
import sys
import os
import numpy as np
import onnx
from onnx.reference import ReferenceEvaluator
import onnxruntime
from onnx import ModelProto, TensorProto, helper
import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx
import argparse
import pickle
def test() -> None:
    """Reproduce the TVM Relax ONNX-frontend import failure for ``11.onnx``.

    Reads ``11.onnx`` and ``inputs.pkl`` from the current working directory,
    runs the model with onnxruntime and with ONNX's ``ReferenceEvaluator``,
    checks that the two results agree within a loose tolerance, and finally
    tries to import the model with TVM's ``from_onnx``.

    Exits the process with status 1 if onnxruntime cannot execute the model.
    According to the traceback in the report, the final ``from_onnx`` call
    raises ``tvm.error.InternalError`` (non-broadcastable shapes in
    ``relax.add``).
    """
    onnx_model = onnx.load("11.onnx")
    # Configure model format.
    onnx_model.ir_version = 8
    onnx_model.opset_import[0].version = 14

    # SECURITY: pickle.load can execute arbitrary code while deserializing.
    # Only run this reproduction with an inputs.pkl from a trusted source.
    with open("inputs.pkl", "rb") as fp:
        inputs = pickle.load(fp)

    # Run the model through onnx to get the expected result.
    try:
        ort_session = onnxruntime.InferenceSession(
            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        ort_output = ort_session.run([], inputs)
    except Exception as e:
        # Top-level boundary of the repro script: report and stop, since
        # there is no baseline to compare against if onnxruntime fails.
        print(e)
        print("This model cannot be executed by onnxruntime!")
        sys.exit(1)
    print("onnxruntime:", ort_output[0])

    # ReferenceEvaluator
    # NOTE(review): this evaluates the file on disk, not `onnx_model`, so the
    # ir_version/opset overrides applied above are not seen by this run.
    sess = ReferenceEvaluator("11.onnx")
    re_output = sess.run(None, inputs)
    print("ReferenceEvaluator:", re_output[0])

    # The two reference implementations must agree before TVM is blamed.
    tvm.testing.assert_allclose(re_output[0], ort_output[0], rtol=0.1, atol=0.1)

    # TVM
    # The import itself is the failing step; the result is intentionally unused.
    tvm_model = from_onnx(onnx_model, opset=14, keep_params_in_input=True)
# Run the reproduction only when executed as a script, not on import.
if __name__ == "__main__":
    test()
Triage
Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).
- needs-triage
Metadata
Metadata
Assignees
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug