
[Bug] ONNX Cast treats NaN inconsistently in TVM LLVM codegen: Constant(NaN)->True but computed NaN->False #18605

@dutZ1855

Description

TVM shows inconsistent ONNX Cast(to=BOOL) behavior for NaN depending on how NaN is produced:

  • Direct NaN constant: Constant(NaN) -> Cast returns True (matches ONNX Runtime / PyTorch).
  • NaN produced by computation: x -> (NaN-producing op) -> Cast returns False in TVM, while ONNX Runtime / PyTorch return True.
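
In eager frameworks the two paths agree. A quick PyTorch check (ours, not part of the attached repro):

import torch

# Both a literal NaN and a computed NaN cast to True in PyTorch.
direct = torch.tensor(float("nan")).to(torch.bool)       # tensor(True)
computed = torch.asin(torch.tensor(5.0)).to(torch.bool)  # asin(5.0) is NaN
print(direct.item(), computed.item())  # True True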

Expected behavior

Per the ONNX Cast operator spec, casting from floating point to bool follows:

  • +/-0.0 -> False
  • all else -> True

Therefore:

  • Cast(NaN -> bool) should be True (NaN is not +0.0/-0.0, so it falls under “all else”).
  • In this repro, Asin(5.0) is NaN because arcsine’s real domain is [-1, 1], so the final output should be True.
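
As a sanity check outside ONNX, NumPy uses the same convention (snippet ours, not from the attached repro):

import numpy as np

# NaN is neither +0.0 nor -0.0, so float -> bool must give True.
assert bool(np.float32("nan").astype(np.bool_)) is True
assert bool(np.float32(0.0).astype(np.bool_)) is False
assert bool(np.float32(-0.0).astype(np.bool_)) is False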

Actual behavior

Taking this model as an example:

Repro model (computed NaN -> Cast): Constant(5.0) -> Asin -> Cast(to=BOOL), opset 18, no graph inputs.

  • ONNX Runtime: True
  • PyTorch: True
  • TVM (Relax, LLVM target): False
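
The attached model.zip should correspond to the following builder sketch (node and tensor names are our own, not taken from the attachment):

import onnx
from onnx import TensorProto, helper

# Constant(5.0) -> Asin -> Cast(to=BOOL), opset 18, no graph inputs.
const = helper.make_node(
    "Constant", [], ["x"],
    value=helper.make_tensor("x_val", TensorProto.FLOAT, [], [5.0]),
)
asin = helper.make_node("Asin", ["x"], ["nan_val"])  # Asin(5.0) is NaN
cast = helper.make_node("Cast", ["nan_val"], ["y"], to=TensorProto.BOOL)
graph = helper.make_graph(
    [const, asin, cast], "cast_nan_to_bool", inputs=[],
    outputs=[helper.make_tensor_value_info("y", TensorProto.BOOL, [])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
onnx.checker.check_model(model)
onnx.save(model, "model.onnx")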

We have also tried other ways of producing NaN:

  • Asin(x) with x=5.0
  • Acos(x) with x=2.0
  • Sqrt(x) with x=-1.0
  • Log(x) with x=-1.0
  • Div(x, x) with x=0.0 (0/0)

All variants behave the same way: TVM returns False while ONNX Runtime and PyTorch return True (see the generator sketch below).
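
A hypothetical generator for all five variants, reusing the builder structure above (Div wires both inputs to the same constant to form 0/0):

import onnx
from onnx import TensorProto, helper

def make_nan_variant(op: str, value: float) -> onnx.ModelProto:
    """Constant(value) -> op -> Cast(to=BOOL), opset 18, no graph inputs."""
    const = helper.make_node(
        "Constant", [], ["x"],
        value=helper.make_tensor("x_val", TensorProto.FLOAT, [], [value]),
    )
    mid_inputs = ["x", "x"] if op == "Div" else ["x"]  # Div(x, x) = 0/0
    mid = helper.make_node(op, mid_inputs, ["nan_val"])
    cast = helper.make_node("Cast", ["nan_val"], ["y"], to=TensorProto.BOOL)
    graph = helper.make_graph(
        [const, mid, cast], f"cast_{op.lower()}_nan", inputs=[],
        outputs=[helper.make_tensor_value_info("y", TensorProto.BOOL, [])],
    )
    return helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])

for op, v in [("Asin", 5.0), ("Acos", 2.0), ("Sqrt", -1.0), ("Log", -1.0), ("Div", 0.0)]:
    onnx.save(make_nan_variant(op, v), f"cast_{op.lower()}_nan.onnx")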

Environment

Operating system: Ubuntu 22.04.4 LTS
TVM version: 0.23.0dev
PyTorch version: 2.9.1
ONNX Runtime version: 1.23.2
ONNX version: 1.20.0
Python version: 3.11.14

Steps to reproduce

model.zip

Download the model and run the following code to obtain the results.

python cast_compare.py --model model.onnx

from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path

import numpy as np
import onnx


def _ensure_repo_tvm() -> None:
    """
    Avoid mixing:
      - repo TVM python (newer)
      - site-packages TVM runtime (older)

    Force-import TVM from this repo's `tvm/python`, and point TVM to `tvm/build`.
    """
    repo_root = Path(__file__).resolve().parents[3]
    tvm_python = repo_root / "tvm" / "python"
    tvm_build = repo_root / "tvm" / "build"

    if tvm_python.exists():
        sys.path.insert(0, tvm_python.as_posix())

    if "TVM_LIBRARY_PATH" not in os.environ and tvm_build.exists():
        os.environ["TVM_LIBRARY_PATH"] = tvm_build.as_posix()

    for k in list(sys.modules.keys()):
        if k == "tvm" or k.startswith("tvm."):
            del sys.modules[k]


def _run_torch() -> bool | None:
    try:
        import torch  # type: ignore
    except Exception:
        return None

    # Directly test the Cast semantics on NaN.
    a = torch.tensor(float("nan"), dtype=torch.float32)
    y = a.to(torch.bool)
    return bool(y.item())


def _run_ort(model_bytes: bytes) -> bool:
    import onnxruntime as ort  # type: ignore

    sess = ort.InferenceSession(model_bytes, providers=["CPUExecutionProvider"])
    outs = sess.run(None, {})
    if len(outs) != 1:
        raise RuntimeError(f"ORT returned {len(outs)} outputs, expected 1")
    y = np.array(outs[0]).item()
    return bool(y)


def _run_tvm(model_path: Path) -> bool:
    _ensure_repo_tvm()
    import tvm  # type: ignore
    from tvm import relax  # type: ignore
    from tvm.relax.frontend import onnx as rx_onnx  # type: ignore

    onnx_model = onnx.load(model_path.as_posix())
    converted = rx_onnx.from_onnx(onnx_model, shape_dict={})
    mod = converted[0] if isinstance(converted, (list, tuple)) else converted

    tgt = tvm.target.Target("llvm")
    pipeline = relax.pipeline.get_default_pipeline(tgt)
    with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": False}):
        ex = relax.build(mod, target=tgt, relax_pipeline=pipeline)
    vm = relax.VirtualMachine(ex, tvm.cpu())

    vm.set_input("main")
    vm.invoke_stateful("main")
    out = vm.get_outputs("main")

    if isinstance(out, tuple):
        out = out[0]
    if hasattr(out, "numpy"):
        arr = out.numpy()
    else:
        arr = np.array(out)
    return bool(np.array(arr).item())


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=Path, default=Path("cast_nan_to_bool.onnx"))
    args = ap.parse_args()

    model_path = args.model.resolve()
    if not model_path.exists():
        print("error: model not found:", model_path)
        return 1

    model_bytes = model_path.read_bytes()

    y_ort = _run_ort(model_bytes)
    y_torch = _run_torch()
    y_tvm = _run_tvm(model_path)

    # Minimal output: just the three backend results.
    print("ort  :", y_ort)
    print("torch:", "skip" if y_torch is None else y_torch)
    print("tvm  :", y_tvm)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())

Triage

  • needs-triage
  • type: bug