Skip to content

[Bug] PRelu with 1-D input fails to import in Relax: relax.nn.prelu uses axis=1 out of range #18606

@dutZ1855

Description

@dutZ1855

Expected behavior

TVM should import, compile and run the model, producing the same results as ONNX Runtime.

Actual behavior

For the following model,

[image] [image] (screenshots of the model; not reproduced in this copy)

it runs successfully under ONNX Runtime and produces the following results:

ONNXRuntime:

 [array([5.627121 , 3.1315434, 6.488241 , 4.1415935, 5.716218 , 4.9254465, 4.222067 , 6.5320206, 4.5028763, 4.85569  , 4.468175 , 6.8016195, 5.443865 , 5.117418 , 5.684251 , 6.9260955, 4.885092 ,
       5.473915 , 5.516653 , 5.740866 , 5.9058757, 4.898214 , 5.5144825, 5.5342417, 4.397482 , 4.531957 , 4.3176513, 3.5980804, 4.0122795, 3.5998032, 5.380638 , 5.7510695], dtype=float32)]

However, when the model is imported and compiled with TVM (Relax ONNX frontend), TVM fails with:
tvm.error.InternalError: In Op(relax.nn.prelu), the input axis 1 is out of range. The input tensor has 1 dimensions, so axis should be in range [-1, 1).

Environment

Operating System: Ubuntu 22.04.4 LTS
TVM version: 0.23.0dev
ONNX Runtime version: 1.23.2
ONNX version: 1.20.0
Python: 3.11.14

Steps to reproduce

This bug can be reproduced with the following script and the model from the attached archive:
model.zip

from __future__ import annotations

import argparse
import os
import pickle
import sys
import traceback
from pathlib import Path

import numpy as np


def _ensure_repo_tvm() -> None:
    repo_root = Path(__file__).resolve().parents[3]
    tvm_python = repo_root / "tvm" / "python"
    tvm_build = repo_root / "tvm" / "build"
    if tvm_python.exists():
        sys.path.insert(0, tvm_python.as_posix())
    if "TVM_LIBRARY_PATH" not in os.environ and tvm_build.exists():
        os.environ["TVM_LIBRARY_PATH"] = tvm_build.as_posix()


def _load_oracle_inputs(path: Path) -> dict[str, np.ndarray]:
    obj = pickle.loads(path.read_bytes())
    inp = obj.get("input", obj)
    if not isinstance(inp, dict):
        raise ValueError("oracle.pkl does not contain a dict input")
    return {k: np.array(v) for k, v in inp.items()}


def _run_ort(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    """Run the model with ONNX Runtime on CPU and print its outputs.

    Prints the full, untruncated list of output arrays. It then prints one
    summary line per output with its name, shape, dtype and (min, max); the
    range is ``None`` for zero-size outputs.

    Args:
        model_path: Path to the ``.onnx`` model.
        inputs: Feed dict mapping model input names to arrays.
    """
    import onnxruntime as ort  # type: ignore

    sess = ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"])
    outs_np = [np.array(v) for v in sess.run(None, inputs)]
    out_names = [o.name for o in sess.get_outputs()]
    # Scope the "print everything" options to this report instead of
    # changing NumPy's global print settings for the rest of the process.
    with np.printoptions(threshold=np.inf, linewidth=200):
        print("ONNXRuntime:\n", outs_np)
    for n, a in zip(out_names, outs_np):
        # min()/max() raise ValueError on zero-size arrays.
        value_range = (a.min(), a.max()) if a.size else None
        print("[ort] output", n, "shape=", a.shape, "dtype=", a.dtype, "min/max=", value_range)


def _run_tvm(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    """Import the ONNX model into Relax and build it for the LLVM target.

    The module goes through ``DecomposeOpsForInference`` and ``LegalizeOps``,
    has its parameters detached, and is built with the default Relax pipeline
    at ``opt_level=3``. The built module is not executed.

    Any exception raised during import or compilation is caught and reported
    (exception type, traceback, message) instead of propagating. A successful
    build is reported as UNEXPECTED, because this script exists to reproduce
    the ``relax.nn.prelu`` axis error.

    Args:
        model_path: Path to the ``.onnx`` model.
        inputs: Feed dict; only the array shapes are used, as the frontend's
            ``shape_dict``.
    """
    _ensure_repo_tvm()
    import onnx  # type: ignore
    import tvm  # type: ignore
    from tvm import relax  # type: ignore
    from tvm.relax.frontend import onnx as rx_onnx  # type: ignore

    model = onnx.load(model_path.as_posix())
    shapes = {}
    for name, arr in inputs.items():
        shapes[name] = arr.shape
    print("[tvm] shape_dict:", shapes)
    try:
        module = rx_onnx.from_onnx(model, shape_dict=shapes)
        # Some frontend versions return a tuple whose first item is the IRModule.
        if isinstance(module, (list, tuple)):
            module = module[0]
        for make_pass in (relax.transform.DecomposeOpsForInference, relax.transform.LegalizeOps):
            module = make_pass()(module)
        module, params = relax.frontend.detach_params(module)
        target = tvm.target.Target("llvm")
        pipeline = relax.pipeline.get_default_pipeline(target)
        with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": False}):
            relax.build(module, target=target, params=params, relax_pipeline=pipeline)
        print("[tvm] UNEXPECTED: succeeded (no axis error)")
    except Exception as err:
        print("[tvm] FAILED:", type(err).__name__)
        traceback.print_exc()
        print("\n[tvm] error message:\n" + str(err))


def main() -> int:
    """Command-line entry point for the reproduction.

    Reads ``--model`` (default ``model.onnx``) and ``--oracle`` (default
    ``oracle.pkl``) and loads the input feed from the oracle file. It then
    runs the model first under ONNX Runtime and then under TVM.

    Returns:
        0 once both runs have finished. TVM import/build errors are caught
        and reported inside ``_run_tvm``. Errors from loading the inputs or
        from ONNX Runtime propagate to the caller.
    """
    cli = argparse.ArgumentParser()
    for flag, default_name in (("--model", "model.onnx"), ("--oracle", "oracle.pkl")):
        cli.add_argument(flag, type=Path, default=Path(default_name))
    opts = cli.parse_args()

    model_file = opts.model.resolve()
    oracle_file = opts.oracle.resolve()
    feed = _load_oracle_inputs(oracle_file)

    _run_ort(model_file, feed)
    _run_tvm(model_file, feed)
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())

Triage

  • needs-triage

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions