Description
Expected behavior
TVM should compile and run the model, producing outputs consistent with ONNXRuntime.
Actual behavior
The attached model runs correctly under onnxruntime; its output is shown below:
ONNXRuntime:
[array([5.627121 , 3.1315434, 6.488241 , 4.1415935, 5.716218 , 4.9254465, 4.222067 , 6.5320206, 4.5028763, 4.85569 , 4.468175 , 6.8016195, 5.443865 , 5.117418 , 5.684251 , 6.9260955, 4.885092 ,
5.473915 , 5.516653 , 5.740866 , 5.9058757, 4.898214 , 5.5144825, 5.5342417, 4.397482 , 4.531957 , 4.3176513, 3.5980804, 4.0122795, 3.5998032, 5.380638 , 5.7510695], dtype=float32)]
However, when compiling and running the same model with TVM, it crashes:
tvm.error.InternalError: In Op(relax.nn.prelu), the input axis 1 is out of range. The input tensor has 1 dimensions, so axis should be in range [-1, 1).
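The failing check appears to be the axis-range validation of relax.nn.prelu itself. Below is a minimal sketch, not taken from the attached model, that may trigger the same error directly; it assumes the op is exposed in Python as relax.op.nn.prelu with an axis argument defaulting to 1, which is inferred from the error message rather than confirmed.

# Hypothetical minimal reproduction (assumption: the op is exposed as
# relax.op.nn.prelu and accepts an `axis` keyword, default 1).
import tvm
from tvm import relax

# A rank-1 input: per the error message, axis=1 is then outside [-1, 1).
data = relax.Var("data", relax.TensorStructInfo((32,), "float32"))
alpha = relax.Var("alpha", relax.TensorStructInfo((1,), "float32"))

bb = relax.BlockBuilder()
with bb.function("main", [data, alpha]):
    # Struct-info inference of this call is expected to raise
    # tvm.error.InternalError: "the input axis 1 is out of range".
    out = bb.emit(relax.op.nn.prelu(data, alpha, axis=1))
    bb.emit_func_output(out)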
Environment
Operating System:Ubuntu 22.04.4 LTS
TVM version:0.23.0dev
ort version:1.23.2
onnx version: 1.20.0
python:3.11.14
Steps to reproduce
The bug can be reproduced with the script below and the model in the attached archive:
model.zip
from __future__ import annotations

import argparse
import os
import pickle
import sys
import traceback
from pathlib import Path

import numpy as np


def _ensure_repo_tvm() -> None:
    # Point sys.path and TVM_LIBRARY_PATH at an in-repo TVM checkout/build.
    repo_root = Path(__file__).resolve().parents[3]
    tvm_python = repo_root / "tvm" / "python"
    tvm_build = repo_root / "tvm" / "build"
    if tvm_python.exists():
        sys.path.insert(0, tvm_python.as_posix())
    if "TVM_LIBRARY_PATH" not in os.environ and tvm_build.exists():
        os.environ["TVM_LIBRARY_PATH"] = tvm_build.as_posix()


def _load_oracle_inputs(path: Path) -> dict[str, np.ndarray]:
    # oracle.pkl stores the model inputs as a dict of input name -> array.
    obj = pickle.loads(path.read_bytes())
    inp = obj.get("input", obj)
    if not isinstance(inp, dict):
        raise ValueError("oracle.pkl does not contain a dict input")
    return {k: np.array(v) for k, v in inp.items()}


def _run_ort(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    import onnxruntime as ort  # type: ignore

    np.set_printoptions(threshold=np.inf, linewidth=200)
    sess = ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"])
    outs = sess.run(None, inputs)
    out_names = [o.name for o in sess.get_outputs()]
    outs_np = [np.array(v) for v in outs]
    print("ONNXRuntime:\n", outs_np)
    for n, a in zip(out_names, outs_np):
        print("[ort] output", n, "shape=", a.shape, "dtype=", a.dtype, "min/max=", (a.min(), a.max()))


def _run_tvm(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    _ensure_repo_tvm()
    import onnx  # type: ignore
    import tvm  # type: ignore
    from tvm import relax  # type: ignore
    from tvm.relax.frontend import onnx as rx_onnx  # type: ignore

    onnx_model = onnx.load(model_path.as_posix())
    shape_dict = {k: v.shape for k, v in inputs.items()}
    print("[tvm] shape_dict:", shape_dict)
    try:
        converted = rx_onnx.from_onnx(onnx_model, shape_dict=shape_dict)
        mod = converted[0] if isinstance(converted, (list, tuple)) else converted
        mod = relax.transform.DecomposeOpsForInference()(mod)
        mod = relax.transform.LegalizeOps()(mod)
        mod, params = relax.frontend.detach_params(mod)
        tgt = tvm.target.Target("llvm")
        pipeline = relax.pipeline.get_default_pipeline(tgt)
        with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": False}):
            _ = relax.build(mod, target=tgt, params=params, relax_pipeline=pipeline)
        print("[tvm] UNEXPECTED: succeeded (no axis error)")
    except Exception as e:
        print("[tvm] FAILED:", type(e).__name__)
        traceback.print_exc()
        print("\n[tvm] error message:\n" + str(e))


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=Path, default=Path("model.onnx"))
    ap.add_argument("--oracle", type=Path, default=Path("oracle.pkl"))
    args = ap.parse_args()
    model_path = args.model.resolve()
    oracle_path = args.oracle.resolve()
    inputs = _load_oracle_inputs(oracle_path)
    _run_ort(model_path, inputs)
    _run_tvm(model_path, inputs)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
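Assuming the script is saved as reproduce.py (the filename is arbitrary) next to the model.onnx and oracle.pkl extracted from model.zip, it can be run as:

python reproduce.py --model model.onnx --oracle oracle.pkl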