74 changes: 37 additions & 37 deletions functions/src/onnx_utils/function.yaml

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions functions/src/onnx_utils/item.yaml
@@ -13,7 +13,7 @@ labels:
 author: Iguazio
 maintainers: []
 marketplaceType: ''
-mlrunVersion: 1.7.2
+mlrunVersion: 1.10.0
 name: onnx_utils
 platformVersion: 3.5.0
 spec:
@@ -30,13 +30,13 @@ spec:
   - tqdm~=4.67.1
   - tensorflow~=2.19.0
   - tf_keras~=2.19.0
-  - torch~=2.6.0
-  - torchvision~=0.21.0
+  - torch~=2.8.0
+  - torchvision~=0.23.0
   - onnx~=1.17.0
   - onnxruntime~=1.19.2
   - onnxoptimizer~=0.3.13
   - onnxmltools~=1.13.0
   - tf2onnx~=1.16.1
   - plotly~=5.23
 url: ''
-version: 1.3.0
+version: 1.4.0
7 changes: 3 additions & 4 deletions functions/src/onnx_utils/requirements.txt
@@ -1,11 +1,10 @@
 tqdm~=4.67.1
 tensorflow~=2.19.0
 tf_keras~=2.19.0
-torch~=2.6.0
-torchvision~=0.21.0
+torch~=2.8
+torchvision~=0.23.0
 onnx~=1.17.0
 onnxruntime~=1.19.2
 onnxoptimizer~=0.3.13
 onnxmltools~=1.13.0
 tf2onnx~=1.16.1
-plotly~=5.23
+plotly~=5.23
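
The requirements bump moves torch from the 2.6 line to 2.8 and torchvision from 0.21 to 0.23 (note that item.yaml pins torch~=2.8.0 while requirements.txt uses the looser torch~=2.8). A rough pre-flight check along the following lines can confirm a local environment matches the new pins before running the test suite; this helper is hypothetical and not part of the PR:

# Hypothetical pre-flight check (not part of this PR): approximate the
# ~= (compatible release) pins above against installed package metadata.
from importlib.metadata import version

for package, expected_prefix in [
    ("torch", "2.8"),
    ("torchvision", "0.23"),
    ("onnx", "1.17"),
]:
    installed = version(package)
    assert installed.startswith(expected_prefix), (
        f"{package}=={installed} does not satisfy ~={expected_prefix}"
    )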
108 changes: 91 additions & 17 deletions functions/src/onnx_utils/test_onnx_utils.py
@@ -17,6 +17,10 @@
 import tempfile
 
 import mlrun
+import pytest
+
+# Project name for tests (must match conftest.py)
+PROJECT_NAME = "onnx-utils"
 
 # Choose our model's name:
 MODEL_NAME = "model"
@@ -27,6 +31,30 @@
 # Choose our optimized ONNX version model's name:
 OPTIMIZED_ONNX_MODEL_NAME = f"optimized_{ONNX_MODEL_NAME}"
 
+REQUIRED_ENV_VARS = [
+    "MLRUN_DBPATH",
+    "MLRUN_ARTIFACT_PATH",
+]
+
+
+def _validate_environment_variables() -> bool:
+    """
+    Checks that all required environment variables are set.
+    """
+    environment_keys = os.environ.keys()
+    return all(key in environment_keys for key in REQUIRED_ENV_VARS)
+
+
+def _is_tf2onnx_available() -> bool:
+    """
+    Check if tf2onnx is installed (required for TensorFlow/Keras ONNX conversion).
+    """
+    try:
+        import tf2onnx
+        return True
+    except ImportError:
+        return False
+
+
 def _setup_environment() -> str:
     """
@@ -52,6 +80,11 @@ def _cleanup_environment(artifact_path: str):
"runs",
"artifacts",
"functions",
"model.pt",
"model.zip",
"model_modules_map.json",
"onnx_model.onnx",
"optimized_onnx_model.onnx",
]:
test_output_path = os.path.abspath(f"./{test_output}")
if os.path.exists(test_output_path):
@@ -114,6 +147,14 @@ def _log_pytorch_model(context: mlrun.MLClientCtx, model_name: str):
     model_handler.log()
 
 
+@pytest.mark.skipif(
+    condition=not _validate_environment_variables(),
+    reason="Project's environment variables are not set",
+)
+@pytest.mark.skipif(
+    condition=not _is_tf2onnx_available(),
+    reason="tf2onnx is not installed",
+)
 def test_to_onnx_help():
     """
     Test the 'to_onnx' handler, passing "help" in the 'framework_kwargs'.
@@ -125,27 +166,28 @@ def test_to_onnx_help():
     log_model_function = mlrun.code_to_function(
         filename="test_onnx_utils.py",
         name="log_model",
+        project=PROJECT_NAME,
         kind="job",
         image="mlrun/ml-models",
     )
 
     # Run the function to log the model:
     log_model_run = log_model_function.run(
         handler="_log_tf_keras_model",
-        artifact_path=artifact_path,
+        output_path=artifact_path,
         params={"model_name": MODEL_NAME},
         local=True,
     )
 
     # Import the ONNX Utils function:
-    onnx_function = mlrun.import_function("function.yaml")
+    onnx_function = mlrun.import_function("function.yaml", project=PROJECT_NAME)
 
     # Run the function, passing "help" in 'framework_kwargs' and see that no exception was raised:
     is_test_passed = True
     try:
         onnx_function.run(
             handler="to_onnx",
-            artifact_path=artifact_path,
+            output_path=artifact_path,
             params={
                 # Take the logged model from the previous function.
                 "model_path": log_model_run.status.artifacts[0]["spec"]["target_path"],
@@ -166,6 +208,14 @@
     assert is_test_passed
 
 
+@pytest.mark.skipif(
+    condition=not _validate_environment_variables(),
+    reason="Project's environment variables are not set",
+)
+@pytest.mark.skipif(
+    condition=not _is_tf2onnx_available(),
+    reason="tf2onnx is not installed",
+)
 def test_tf_keras_to_onnx():
     """
     Test the 'to_onnx' handler, giving it a tf.keras model.
@@ -177,25 +227,26 @@ def test_tf_keras_to_onnx():
     log_model_function = mlrun.code_to_function(
         filename="test_onnx_utils.py",
         name="log_model",
+        project=PROJECT_NAME,
         kind="job",
         image="mlrun/ml-models",
     )
 
     # Run the function to log the model:
     log_model_run = log_model_function.run(
         handler="_log_tf_keras_model",
-        artifact_path=artifact_path,
+        output_path=artifact_path,
         params={"model_name": MODEL_NAME},
         local=True,
     )
 
     # Import the ONNX Utils function:
-    onnx_function = mlrun.import_function("function.yaml")
+    onnx_function = mlrun.import_function("function.yaml", project=PROJECT_NAME)
 
     # Run the function to convert our model to ONNX:
     onnx_function_run = onnx_function.run(
         handler="to_onnx",
-        artifact_path=artifact_path,
+        output_path=artifact_path,
         params={
             # Take the logged model from the previous function.
             "model_path": log_model_run.status.artifacts[0]["spec"]["target_path"],
@@ -215,6 +266,10 @@
assert "model" in onnx_function_run.outputs


@pytest.mark.skipif(
condition=not _validate_environment_variables(),
reason="Project's environment variables are not set",
)
def test_pytorch_to_onnx():
"""
Test the 'to_onnx' handler, giving it a pytorch model.
@@ -226,32 +281,38 @@
     log_model_function = mlrun.code_to_function(
         filename="test_onnx_utils.py",
         name="log_model",
+        project=PROJECT_NAME,
         kind="job",
         image="mlrun/ml-models",
     )
 
     # Run the function to log the model:
     log_model_run = log_model_function.run(
         handler="_log_pytorch_model",
-        artifact_path=artifact_path,
+        output_path=artifact_path,
         params={"model_name": MODEL_NAME},
         local=True,
     )
 
     # Import the ONNX Utils function:
-    onnx_function = mlrun.import_function("function.yaml")
+    onnx_function = mlrun.import_function("function.yaml", project=PROJECT_NAME)
 
+    # Get artifact paths - construct from artifact_path and run structure:
+    run_artifact_dir = os.path.join(artifact_path, "log-model--log-pytorch-model", "0")
+    model_path = os.path.join(run_artifact_dir, "model")
+    modules_map_path = os.path.join(run_artifact_dir, "model_modules_map.json.json")
+
     # Run the function to convert our model to ONNX:
     onnx_function_run = onnx_function.run(
         handler="to_onnx",
-        artifact_path=artifact_path,
+        output_path=artifact_path,
         params={
             # Take the logged model from the previous function.
-            "model_path": log_model_run.status.artifacts[1]["spec"]["target_path"],
+            "model_path": model_path,
             "load_model_kwargs": {
                 "model_name": MODEL_NAME,
                 "model_class": "mobilenet_v2",
-                "modules_map": log_model_run.status.artifacts[0]["spec"]["target_path"],
+                "modules_map": modules_map_path,
             },
             "onnx_model_name": ONNX_MODEL_NAME,
             "framework_kwargs": {"input_signature": [((32, 3, 224, 224), "float32")]},
@@ -269,6 +330,10 @@
assert "model" in onnx_function_run.outputs


@pytest.mark.skipif(
condition=not _validate_environment_variables(),
reason="Project's environment variables are not set",
)
def test_optimize_help():
"""
Test the 'optimize' handler, passing "help" in the 'optimizations'.
@@ -277,14 +342,14 @@
     artifact_path = _setup_environment()
 
     # Import the ONNX Utils function:
-    onnx_function = mlrun.import_function("function.yaml")
+    onnx_function = mlrun.import_function("function.yaml", project=PROJECT_NAME)
 
     # Run the function, passing "help" in 'optimizations' and see that no exception was raised:
     is_test_passed = True
     try:
         onnx_function.run(
             handler="optimize",
-            artifact_path=artifact_path,
+            output_path=artifact_path,
             params={
                 "model_path": "",
                 "optimizations": "help",
@@ -303,6 +368,14 @@
     assert is_test_passed
 
 
+@pytest.mark.skipif(
+    condition=not _validate_environment_variables(),
+    reason="Project's environment variables are not set",
+)
+@pytest.mark.skipif(
+    condition=not _is_tf2onnx_available(),
+    reason="tf2onnx is not installed",
+)
 def test_optimize():
     """
     Test the 'optimize' handler, giving it a model from the ONNX zoo git repository.
@@ -314,25 +387,26 @@
     log_model_function = mlrun.code_to_function(
         filename="test_onnx_utils.py",
         name="log_model",
+        project=PROJECT_NAME,
         kind="job",
         image="mlrun/ml-models",
     )
 
     # Run the function to log the model:
     log_model_run = log_model_function.run(
         handler="_log_tf_keras_model",
-        artifact_path=artifact_path,
+        output_path=artifact_path,
         params={"model_name": MODEL_NAME},
         local=True,
     )
 
     # Import the ONNX Utils function:
-    onnx_function = mlrun.import_function("function.yaml")
+    onnx_function = mlrun.import_function("function.yaml", project=PROJECT_NAME)
 
     # Run the function to convert our model to ONNX:
     to_onnx_function_run = onnx_function.run(
         handler="to_onnx",
-        artifact_path=artifact_path,
+        output_path=artifact_path,
         params={
             # Take the logged model from the previous function.
             "model_path": log_model_run.status.artifacts[0]["spec"]["target_path"],
@@ -345,7 +419,7 @@
     # Run the function to optimize our model:
     optimize_function_run = onnx_function.run(
         handler="optimize",
-        artifact_path=artifact_path,
+        output_path=artifact_path,
         params={
             # Take the logged model from the previous function.
             "model_path": to_onnx_function_run.status.artifacts[0]["spec"][