Description
TVM shows inconsistent ONNX `Cast(to=BOOL)` behavior for NaN depending on how the NaN is produced:
- Direct NaN constant: `Constant(NaN) -> Cast` returns `True` (matches ONNX Runtime / PyTorch).
- NaN produced by computation: `x -> (NaN-producing op) -> Cast` returns `False` in TVM, while ONNX Runtime / PyTorch return `True`.
Expected behavior
Per the ONNX `Cast` operator spec for casting from floating point to bool:
- `+/-0.0` → `False`
- all else → `True`

Therefore:
- `Cast(NaN -> bool)` should be `True` (NaN is not `+0.0`/`-0.0`, so it falls under "all else").
- In this repro, `Asin(5.0)` is NaN because arcsine's real domain is `[-1, 1]`, so the final output should be `True`.
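As a quick sanity check on the expected semantics, NumPy's float-to-bool conversion follows the same rule as the ONNX spec quoted above:

```python
import numpy as np

# Only +/-0.0 map to False; everything else, including NaN and inf, maps to True.
vals = np.array([0.0, -0.0, 1.0, float("nan"), float("inf")], dtype=np.float32)
print(vals.astype(np.bool_))  # [False False  True  True  True]
```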
Actual behavior
Taking this model as an example:

Repro model (computed NaN → Cast): `Constant(5.0) -> Asin -> Cast(to=BOOL)` (opset 18, input-free)
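For reference, a minimal sketch of how an equivalent input-free model can be built with `onnx.helper` (tensor and node names here are illustrative, not necessarily those of the attached model):

```python
import onnx
from onnx import TensorProto, helper

# Constant(5.0) -> Asin -> Cast(to=BOOL), with no graph inputs.
const = helper.make_node(
    "Constant", inputs=[], outputs=["x"],
    value=helper.make_tensor("x_val", TensorProto.FLOAT, [], [5.0]),
)
asin = helper.make_node("Asin", inputs=["x"], outputs=["y"])
cast = helper.make_node("Cast", inputs=["y"], outputs=["out"], to=TensorProto.BOOL)

graph = helper.make_graph(
    [const, asin, cast], "cast_nan_to_bool",
    inputs=[],  # input-free: the value comes from the Constant node
    outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, [])],
)
model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 18)])
onnx.checker.check_model(model)
onnx.save(model, "model.onnx")
```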
- ONNX Runtime: `True`
- PyTorch: `True`
- TVM (Relax, LLVM target): `False`
We have also tried other ways to generate NaN:
- `Asin(x)` with `x=5.0`
- `Acos(x)` with `x=2.0`
- `Sqrt(x)` with `x=-1.0`
- `Log(x)` with `x=-1.0`
- `Div(x, x)` with `x=0.0` (0/0)

The results are consistent with the above; a quick NumPy check of these inputs is sketched below.
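Each of these variants indeed produces NaN under IEEE-754 semantics, which can be verified with NumPy:

```python
import numpy as np

# Suppress the expected "invalid value" RuntimeWarnings; every result is nan.
with np.errstate(invalid="ignore", divide="ignore"):
    print(np.arcsin(np.float32(5.0)))         # nan (|x| > 1)
    print(np.arccos(np.float32(2.0)))         # nan (|x| > 1)
    print(np.sqrt(np.float32(-1.0)))          # nan (negative input)
    print(np.log(np.float32(-1.0)))           # nan (negative input)
    print(np.float32(0.0) / np.float32(0.0))  # nan (0/0)
```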
Environment
- Operating System: Ubuntu 22.04.4 LTS
- TVM version: 0.23.0dev
- PyTorch version: 2.9.1
- onnxruntime version: 1.23.2
- onnx version: 1.20.0
- Python version: 3.11.14
Steps to reproduce
Download the model and run the following script to obtain the results:

```
python cast_compare.py --model model.onnx
```
```python
from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path

import numpy as np
import onnx


def _ensure_repo_tvm() -> None:
    """
    Avoid mixing:
      - repo TVM python (newer)
      - site-packages TVM runtime (older)
    Force-import TVM from this repo's `tvm/python`, and point TVM to `tvm/build`.
    """
    repo_root = Path(__file__).resolve().parents[3]
    tvm_python = repo_root / "tvm" / "python"
    tvm_build = repo_root / "tvm" / "build"
    if tvm_python.exists():
        sys.path.insert(0, tvm_python.as_posix())
    if "TVM_LIBRARY_PATH" not in os.environ and tvm_build.exists():
        os.environ["TVM_LIBRARY_PATH"] = tvm_build.as_posix()
    # Drop any already-imported TVM modules so the repo copy wins.
    for k in list(sys.modules.keys()):
        if k == "tvm" or k.startswith("tvm."):
            del sys.modules[k]


def _run_torch() -> bool | None:
    try:
        import torch
    except Exception:
        return None
    # Directly test the Cast semantics on NaN.
    a = torch.tensor(float("nan"), dtype=torch.float32)
    y = a.to(torch.bool)
    return bool(y.item())


def _run_ort(model_bytes: bytes) -> bool:
    import onnxruntime as ort  # type: ignore

    sess = ort.InferenceSession(model_bytes, providers=["CPUExecutionProvider"])
    outs = sess.run(None, {})
    if len(outs) != 1:
        raise RuntimeError(f"ORT returned {len(outs)} outputs, expected 1")
    y = np.array(outs[0]).item()
    return bool(y)


def _run_tvm(model_path: Path) -> bool:
    _ensure_repo_tvm()
    import tvm  # type: ignore
    from tvm import relax  # type: ignore
    from tvm.relax.frontend import onnx as rx_onnx  # type: ignore

    onnx_model = onnx.load(model_path.as_posix())
    converted = rx_onnx.from_onnx(onnx_model, shape_dict={})
    mod = converted[0] if isinstance(converted, (list, tuple)) else converted
    tgt = tvm.target.Target("llvm")
    pipeline = relax.pipeline.get_default_pipeline(tgt)
    with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": False}):
        ex = relax.build(mod, target=tgt, relax_pipeline=pipeline)
    vm = relax.VirtualMachine(ex, tvm.cpu())
    vm.set_input("main")
    vm.invoke_stateful("main")
    out = vm.get_outputs("main")
    if isinstance(out, tuple):
        out = out[0]
    if hasattr(out, "numpy"):
        arr = out.numpy()
    else:
        arr = np.array(out)
    return bool(np.array(arr).item())


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=Path, default=Path("cast_nan_to_bool.onnx"))
    args = ap.parse_args()
    model_path = args.model.resolve()
    if not model_path.exists():
        print("error: model not found:", model_path)
        return 1
    model_bytes = model_path.read_bytes()
    y_ort = _run_ort(model_bytes)
    y_torch = _run_torch()
    y_tvm = _run_tvm(model_path)
    # Minimal output: just the three backend results.
    print("ort :", y_ort)
    print("torch:", "skip" if y_torch is None else y_torch)
    print("tvm :", y_tvm)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```
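On the environment above, the script prints the results reported earlier (the `torch` line reads `skip` if PyTorch is not installed):

```
ort : True
torch: True
tvm : False
```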