diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 760dc3570f..6647e01cc8 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, - spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateDiff, SparkDateTrunc, - SparkSizeFunc, SparkStringSpace, + spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateDiff, + SparkDateTrunc, SparkSizeFunc, SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -192,6 +192,7 @@ pub fn create_comet_physical_fun_with_eval_mode( fn all_scalar_functions() -> Vec> { vec![ Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())), + Arc::new(ScalarUDF::new_from_impl(SparkContains::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())), Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())), diff --git a/native/spark-expr/src/string_funcs/contains.rs b/native/spark-expr/src/string_funcs/contains.rs new file mode 100644 index 0000000000..bc34ce9cba --- /dev/null +++ b/native/spark-expr/src/string_funcs/contains.rs @@ -0,0 +1,246 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimized `contains` string function for Spark compatibility. +//! +//! Optimized for scalar pattern case by passing scalar directly to arrow_contains +//! instead of expanding to arrays like DataFusion's built-in contains. + +use arrow::array::{Array, ArrayRef, BooleanArray, StringArray}; +use arrow::compute::kernels::comparison::contains as arrow_contains; +use arrow::datatypes::DataType; +use datafusion::common::{exec_err, Result, ScalarValue}; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +/// Spark-optimized contains function. +/// Returns true if the first string argument contains the second string argument. +/// Optimized for scalar pattern constants. 
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkContains {
+    /// Declared as `variadic_any`; the argument count is checked in `invoke_with_args`.
+    signature: Signature,
+}
+
+impl Default for SparkContains {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkContains {
+    /// Creates the UDF with an immutable `variadic_any` signature.
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkContains {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "contains"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    /// The result is always boolean, independent of the argument types.
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    /// Evaluates `contains(haystack, needle)`.
+    ///
+    /// # Errors
+    /// Returns an execution error unless exactly two arguments are supplied, or
+    /// when the underlying evaluation fails.
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        // A slice pattern checks the arity and binds both operands in one step.
+        let [haystack, needle] = args.args.as_slice() else {
+            return exec_err!("contains function requires exactly 2 arguments");
+        };
+        spark_contains(haystack, needle)
+    }
+}
+
+/// Dispatches `contains` on the array/scalar shape of the two operands.
+///
+/// The result is an array whenever at least one operand is an array, and a
+/// scalar only when both operands are scalars.
+fn spark_contains(haystack: &ColumnarValue, needle: &ColumnarValue) -> Result<ColumnarValue> {
+    match (haystack, needle) {
+        // Both arrays: hand them to Arrow's kernel unchanged.
+        (ColumnarValue::Array(haystack_array), ColumnarValue::Array(needle_array)) => {
+            let result = arrow_contains(haystack_array, needle_array)?;
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+
+        // Array haystack, scalar needle: the path this function is optimized for.
+        (ColumnarValue::Array(haystack_array), ColumnarValue::Scalar(needle_scalar)) => {
+            let result = contains_with_arrow_scalar(haystack_array, needle_scalar)?;
+            Ok(ColumnarValue::Array(result))
+        }
+
+        // Scalar haystack, array needle: less common, so the haystack is simply
+        // expanded to the needle's length.
+        (ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Array(needle_array)) => {
+            let haystack_array = haystack_scalar.to_array_of_size(needle_array.len())?;
+            let result = arrow_contains(&haystack_array, needle_array)?;
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+
+        // Both scalars: compute a single value.
+        (ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Scalar(needle_scalar)) => {
+            let result = contains_scalar_scalar(haystack_scalar, needle_scalar)?;
+            Ok(ColumnarValue::Scalar(result))
+        }
+    }
+}
+
+/// Evaluates `contains` for an array haystack against a single scalar needle.
+///
+/// A NULL needle yields an all-NULL boolean array of the haystack's length.
+/// Otherwise the needle is wrapped as an Arrow scalar so the kernel can take its
+/// scalar path instead of receiving an expanded needle array.
+///
+/// # Errors
+/// Returns an execution error when the needle is not a string scalar, and
+/// propagates any error raised by Arrow's `contains` kernel (for example when
+/// the haystack is not a string array).
+fn contains_with_arrow_scalar(
+    haystack_array: &ArrayRef,
+    needle_scalar: &ScalarValue,
+) -> Result<ArrayRef> {
+    if needle_scalar.is_null() {
+        return Ok(Arc::new(BooleanArray::new_null(haystack_array.len())));
+    }
+
+    // Borrow the needle; `new_scalar` accepts any `AsRef<str>`, so no clone is needed.
+    let needle_str = match needle_scalar {
+        ScalarValue::Utf8(Some(s))
+        | ScalarValue::LargeUtf8(Some(s))
+        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
+        _ => {
+            return exec_err!(
+                "contains function requires string type for needle, got {:?}",
+                needle_scalar.data_type()
+            )
+        }
+    };
+
+    // Arrow's kernel only accepts operands of the same string type, so the needle
+    // scalar is built in the haystack's type. For a dictionary haystack the
+    // dictionary's value type is the one that has to match.
+    let string_type = match haystack_array.data_type() {
+        DataType::Dictionary(_, value_type) => value_type.as_ref(),
+        other => other,
+    };
+
+    let result = match string_type {
+        DataType::LargeUtf8 => arrow_contains(
+            haystack_array,
+            &arrow::array::LargeStringArray::new_scalar(needle_str),
+        )?,
+        DataType::Utf8View => arrow_contains(
+            haystack_array,
+            &arrow::array::StringViewArray::new_scalar(needle_str),
+        )?,
+        // Utf8, plus any unsupported type, which Arrow rejects with its own error.
+        _ => arrow_contains(haystack_array, &StringArray::new_scalar(needle_str))?,
+    };
+    Ok(Arc::new(result))
+}
+
+/// Contains for two scalar values.
+///
+/// Yields a NULL boolean when either operand is NULL; otherwise both operands
+/// must be string scalars and the result is a plain substring test.
+///
+/// # Errors
+/// Returns an execution error when a non-null operand is not a string scalar.
+fn contains_scalar_scalar(
+    haystack_scalar: &ScalarValue,
+    needle_scalar: &ScalarValue,
+) -> Result<ScalarValue> {
+    // NULL on either side short-circuits before any type checking.
+    if haystack_scalar.is_null() || needle_scalar.is_null() {
+        return Ok(ScalarValue::Boolean(None));
+    }
+
+    let haystack_str = scalar_as_str(haystack_scalar, "haystack")?;
+    let needle_str = scalar_as_str(needle_scalar, "needle")?;
+
+    Ok(ScalarValue::Boolean(Some(
+        haystack_str.contains(needle_str),
+    )))
+}
+
+/// Borrows the text of a non-null string scalar.
+///
+/// `role` ("haystack" or "needle") is only used to word the error message.
+fn scalar_as_str<'a>(value: &'a ScalarValue, role: &str) -> Result<&'a str> {
+    match value {
+        ScalarValue::Utf8(Some(s))
+        | ScalarValue::LargeUtf8(Some(s))
+        | ScalarValue::Utf8View(Some(s)) => Ok(s.as_str()),
+        other => exec_err!(
+            "contains function requires string type for {}, got {:?}",
+            role,
+            other.data_type()
+        ),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::StringArray;
+
+    #[test]
+    fn test_contains_array_scalar() {
+        let haystack = Arc::new(StringArray::from(vec![
+            Some("hello world"),
+            Some("foo bar"),
+            Some("testing"),
+            None,
+        ])) as ArrayRef;
+        let needle = ScalarValue::Utf8(Some("world".to_string()));
+
+        let result = contains_with_arrow_scalar(&haystack, &needle).unwrap();
+        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        assert!(bool_array.value(0)); // "hello world" contains "world"
+        assert!(!bool_array.value(1)); // "foo bar" does not contain "world"
+        assert!(!bool_array.value(2)); // "testing" does not contain "world"
+        assert!(bool_array.is_null(3)); // null input => null output
+    }
+
+    #[test]
+    fn test_contains_scalar_scalar() {
+        let haystack = ScalarValue::Utf8(Some("hello world".to_string()));
+        let needle = ScalarValue::Utf8(Some("world".to_string()));
+
+        let result = contains_scalar_scalar(&haystack, &needle).unwrap();
+        assert_eq!(result, ScalarValue::Boolean(Some(true)));
+
+        let needle_not_found = ScalarValue::Utf8(Some("xyz".to_string()));
+        let result = contains_scalar_scalar(&haystack, &needle_not_found).unwrap();
+        assert_eq!(result, ScalarValue::Boolean(Some(false)));
+    }
+
+    #[test]
+    fn test_contains_scalar_scalar_null() {
+        let text = ScalarValue::Utf8(Some("hello world".to_string()));
+        let null = ScalarValue::Utf8(None);
+
+        // A NULL on either side produces a NULL boolean.
+        let result = contains_scalar_scalar(&null, &text).unwrap();
+        assert_eq!(result, ScalarValue::Boolean(None));
+        let result = contains_scalar_scalar(&text, &null).unwrap();
+        assert_eq!(result, ScalarValue::Boolean(None));
+    }
+
+    #[test]
+    fn test_contains_scalar_scalar_rejects_non_string() {
+        let text = ScalarValue::Utf8(Some("hello world".to_string()));
+        let number = ScalarValue::Int32(Some(1));
+
+        assert!(contains_scalar_scalar(&number, &text).is_err());
+        assert!(contains_scalar_scalar(&text, &number).is_err());
+    }
+
+    #[test]
+    fn test_contains_null_needle() {
+        let haystack = Arc::new(StringArray::from(vec![
+            Some("hello world"),
+            Some("foo bar"),
+        ])) as ArrayRef;
+        let needle = ScalarValue::Utf8(None);
+
+        let result = contains_with_arrow_scalar(&haystack, &needle).unwrap();
+        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        // Null needle should produce null results
+        assert!(bool_array.is_null(0));
+        assert!(bool_array.is_null(1));
+    }
+
+    #[test]
+    fn test_contains_empty_needle() {
+        let haystack = Arc::new(StringArray::from(vec![Some("hello world"), Some("")])) as ArrayRef;
+        let needle = ScalarValue::Utf8(Some("".to_string()));
+
+        let result = contains_with_arrow_scalar(&haystack, &needle).unwrap();
+        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        // Empty string is contained in any string
+        assert!(bool_array.value(0));
+        assert!(bool_array.value(1));
+    }
+}
diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs
index aac8204e29..abdd0cc89b 100644
--- a/native/spark-expr/src/string_funcs/mod.rs
+++ b/native/spark-expr/src/string_funcs/mod.rs
@@ -15,8 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
+mod contains; mod string_space; mod substring; +pub use contains::SparkContains; pub use string_space::SparkStringSpace; pub use substring::SubstringExpr; diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index fe5ea77a89..b8a635e75b 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1165,7 +1165,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { // Filter rows that contains 'rose' in 'name' column val queryContains = sql(s"select id from $table where contains (name, 'rose')") - checkAnswer(queryContains, Row(5) :: Nil) + checkSparkAnswerAndOperator(queryContains) + + // Additional test cases for optimized contains implementation + // Test with empty pattern (should match all non-null rows) + val queryEmptyPattern = sql(s"select id from $table where contains (name, '')") + checkSparkAnswerAndOperator(queryEmptyPattern) + + // Test with pattern not found + val queryNotFound = sql(s"select id from $table where contains (name, 'xyz')") + checkSparkAnswerAndOperator(queryNotFound) + + // Test with pattern at start + val queryStart = sql(s"select id from $table where contains (name, 'James')") + checkSparkAnswerAndOperator(queryStart) + + // Test with pattern at end + val queryEnd = sql(s"select id from $table where contains (name, 'Smith')") + checkSparkAnswerAndOperator(queryEnd) } }