Skip to content

[Bug] [FRONTEND][ONNX] Error converting operator ConvTranspose: InternalError: In Op(relax.add), the first input shape at dim 1 is T.int64(16) and the second input shape at dim 1 is T.int64(32), which are not broadcastable. #18600

@coffezhou

Description

@coffezhou

Expected behavior

TVM should import and run the model correctly, producing results that match onnxruntime's.

Actual behavior

The following simple model

Image

can be executed by both onnxruntime and onnx's ReferenceEvaluator; the results are as follows:

[[[[-14.942039    25.242477    12.100965   ...  57.370975
     43.9531       2.7949858 ]
   [ 10.37292    -10.740159   -38.21392    ... -25.30576
     20.154476   -19.093735  ]
   [ 23.041002    41.820953   -34.108677   ...  13.982165
      7.2081375   11.132856  ]
   ...
   [ -5.590187   -30.208424     9.572759   ...  32.32894
     10.533618   -10.12186   ]
   [-11.91229     29.118868   -20.662777   ... -16.910463
     -3.0499296   -8.469831  ]
   [ -1.340055   -17.950693     1.7485574  ...  44.479458
     11.766148   -29.151772  ]]

  [[  3.5306716  -17.110134   -12.246135   ...   3.6663613
     10.713657    18.570526  ]
   [ 14.867607    -7.739101   -32.685593   ...  -8.97093
    -14.347578    32.1913    ]
   [  7.462762   -14.961873   -18.268976   ... -16.221258
      9.145517    22.871971  ]
   ...
   [ 58.016407   -28.310326    27.333632   ...  33.982685
     -0.22384354  44.509457  ]
   [-27.212149   -17.831676     7.140829   ...  10.431597
     33.523167   -36.84701   ]
   [  4.2589498  -18.459444     9.673733   ... -21.87571
     11.560403   -11.319146  ]]

  [[  1.6221294  -24.357258    37.832687   ... -15.675017
     17.112104   -19.541176  ]
   [ -1.9535408  -43.395298    -7.589774   ...  32.80785
     39.018574    -9.826303  ]
   [ 23.802937   -22.675396     2.9990005  ... -29.012583
     55.184933    16.124914  ]
   ...
   [ 13.447808     1.2474155   68.31904    ... -33.191654
      7.361887   -14.480252  ]
   [  8.724002   -44.378754    -3.259966   ... -43.776226
     47.07514      4.846405  ]
   [-28.980995     0.75377727 -20.248524   ... -13.6773405
     34.299976    -5.997253  ]]

  ...

  [[-33.108776   -16.460379     7.451367   ...  -3.0716655
      8.792357     9.199628  ]
   [-10.372536   -22.135944   -21.878239   ...  55.7764
     -7.191108   -14.206541  ]
   [-53.47607     10.408365   -20.441448   ...  23.109488
      1.2017591    6.3908386 ]
   ...
   [ -0.9444683  -29.720179   -36.938797   ... -32.23233
    -35.5812      42.359478  ]
   [-28.41974     -0.1681884  -38.875988   ...  14.968706
    -23.108547   -32.480797  ]
   [  2.9518757    1.1772063  -21.35058    ... -24.29589
      1.791322    -2.575415  ]]

  [[ -1.1625404   -6.98105    -10.189097   ... -26.975689
    -14.790653    -8.432305  ]
   [ 17.315321    34.95646    -33.681236   ...  -8.27439
      9.273292   -20.079172  ]
   [ -9.413344    25.616909    -8.439555   ... -17.31526
     18.547462   -18.068047  ]
   ...
   [  0.11532002 -46.406166   -29.583313   ... -17.339268
     17.583694    38.199493  ]
   [-49.912685    10.065314   -67.41777    ... -17.945366
    -17.426779    24.320415  ]
   [ -9.248135     4.745124    -5.5740147  ...  33.417686
     -3.3946886    4.773201  ]]

  [[-11.940619     1.2533135  -15.480752   ...  17.695894
     23.379114    -6.3436594 ]
   [ -2.8251894   -1.4625757   12.186805   ... -25.161877
      6.0388117   -8.681617  ]
   [ 37.58026      7.1358566  -53.616978   ...  49.06318
     -9.981028    15.518544  ]
   ...
   [-19.931892    23.587257   -68.27191    ... -25.605322
    -31.568523     0.6990657 ]
   [ 56.762306    -4.785701     3.260337   ...   2.4171884
    -50.66727    -23.977222  ]
   [ 13.220358   -13.136884    18.563856   ...  -9.229172
    -32.588108    27.032162  ]]]]

However, TVM's ONNX frontend fails to import it:

File "/home/ubuntu/Documents/test1.py", line 49, in test
    tvm_model = from_onnx(onnx_model, opset=14, keep_params_in_input=True)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py", line 4260, in from_onnx
    return g.from_onnx(graph, opset)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py", line 3890, in from_onnx
    self._construct_nodes(graph)
  File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py", line 4071, in _construct_nodes
    raise err
  File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/frontend/onnx/onnx_frontend.py", line 4068, in _construct_nodes
    op = self.bb.normalize(op)
         ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/Documents/DLCompilers/tvm/python/tvm/relax/block_builder.py", line 672, in normalize
    return _ffi_api.BlockBuilderNormalize(self, expr)  # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 904, in tvm_ffi.core.Function.__call__
  File "<unknown>", line 0, in tvm::relax::Normalizer::Normalize(tvm::RelaxExpr const&)
  File "<unknown>", line 0, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
  File "<unknown>", line 0, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#9}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
  File "<unknown>", line 0, in tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
  File "<unknown>", line 0, in tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
  File "<unknown>", line 0, in tvm::relax::InferStructInfoBroadcastArith(tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
  File "<unknown>", line 0, in tvm::relax::StructInfo tvm::relax::InferStructInfoBroadcast<tvm::runtime::DataType (*)(tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::relax::StructInfo const&, tvm::relax::StructInfo const&)>(tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::runtime::DataType (*)(tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::relax::StructInfo const&, tvm::relax::StructInfo const&))
  File "<unknown>", line 0, in tvm::relax::InferBinaryBroadcastShape(tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::ffi::Array<tvm::PrimExpr, void> const&, tvm::ffi::Array<tvm::PrimExpr, void> const&)
  File "<unknown>", line 0, in tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
  File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
tvm.error.InternalError: In Op(relax.add), the first input shape at dim 1 is T.int64(16) and the second input shape at dim 1 is T.int64(32), which are not broadcastable.
[11:03:35] /home/ubuntu/Documents/DLCompilers/tvm/src/relax/ir/block_builder.cc:64: Warning: BlockBuilder destroyed with remaining blocks!

Environment

OS: Ubuntu 20.04
TVM: 0.23.dev0 (f4e28d3)

onnxruntime: 1.23.2

Steps to reproduce

This bug can be reproduced with the following code, using the model (11.onnx) and inputs (inputs.pkl) from the attached testcase.zip.

from typing import Dict, List, Literal, Optional
import sys
import os

import numpy as np
import onnx
from onnx.reference import ReferenceEvaluator

import onnxruntime
from onnx import ModelProto, TensorProto, helper

import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx

import argparse
import pickle

def test(model_path: str = "11.onnx", inputs_path: str = "inputs.pkl") -> None:
    """Reproduce the TVM ONNX-frontend import failure for the attached model.

    The model is first executed with onnxruntime and onnx's ReferenceEvaluator
    to show that it is valid and that both reference runtimes agree. It is
    then imported with TVM's Relax ONNX frontend, which is the step reported
    to raise ``InternalError``.

    Args:
        model_path: Path to the ONNX model file.
        inputs_path: Path to a pickle file holding the feed passed to both
            reference runtimes (presumably a dict of input name -> numpy
            array; confirm against the attached testcase).

    Raises:
        SystemExit: If onnxruntime cannot execute the model.
    """
    onnx_model = onnx.load(model_path)
    # Configure model format. Every consumer below uses this patched
    # in-memory model, so they all see the same IR/opset versions.
    onnx_model.ir_version = 8
    onnx_model.opset_import[0].version = 14

    # NOTE(review): pickle.load can execute arbitrary code; only load
    # input files from a trusted source.
    with open(inputs_path, "rb") as fp:
        inputs = pickle.load(fp)

    # Run the model through onnxruntime to get the expected result.
    try:
        ort_session = onnxruntime.InferenceSession(
            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        ort_output = ort_session.run([], inputs)
    except Exception as e:
        print(e)
        print("This model cannot be executed by onnxruntime!")
        sys.exit(1)

    print("onnxruntime:", ort_output[0])

    # ReferenceEvaluator: evaluate the same patched in-memory model rather
    # than re-reading the unpatched file from disk, so the comparison with
    # onnxruntime is made on an identical model.
    sess = ReferenceEvaluator(onnx_model)
    re_output = sess.run(None, inputs)
    print("ReferenceEvaluator:", re_output[0])

    tvm.testing.assert_allclose(re_output[0], ort_output[0], rtol=0.1, atol=0.1)

    # TVM: this import is where the reported error is raised. The result is
    # intentionally unused; reaching past this line means the bug is fixed.
    from_onnx(onnx_model, opset=14, keep_params_in_input=True)


if __name__ == "__main__":
    test()
 

testcase.zip

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage

cc @KJlaccHoeUM9l

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions