diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index b552a071d6..0d8013ffb6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -132,9 +132,108 @@ object CometArrayContains extends CometExpressionSerde[ArrayContains] {
     val arrayExprProto = exprToProto(expr.children.head, inputs, binding)
     val keyExprProto = exprToProto(expr.children(1), inputs, binding)
 
-    val arrayContainsScalarExpr =
+    // Check if array is null - if so, return null
+    val isArrayNotNullExpr = createUnaryExpr(
+      expr,
+      expr.children.head,
+      inputs,
+      binding,
+      (builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
+
+    // Check if search value is null - if so, return null
+    val isKeyNotNullExpr = createUnaryExpr(
+      expr,
+      expr.children(1),
+      inputs,
+      binding,
+      (builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
+
+    // Check if value exists in array
+    val arrayHasValueExpr =
       scalarFunctionExprToProto("array_has", arrayExprProto, keyExprProto)
-    optExprWithInfo(arrayContainsScalarExpr, expr, expr.children: _*)
+
+    // Check if array contains null elements (for three-valued logic)
+    val nullKeyLiteralProto = exprToProto(Literal(null, expr.children(1).dataType), Seq.empty)
+    val arrayHasNullExpr =
+      scalarFunctionExprToProto("array_has", arrayExprProto, nullKeyLiteralProto)
+
+    // Build the three-valued logic:
+    // 1. If array is null -> return null
+    // 2. If key is null -> return null
+    // 3. If array_has(array, key) is true -> return true
+    // 4. If array_has(array, key) is false AND array_has(array, null) is true
+    //    -> return null (indeterminate)
+    // 5. If array_has(array, key) is false AND array_has(array, null) is false
+    //    -> return false
+    if (isArrayNotNullExpr.isDefined && isKeyNotNullExpr.isDefined &&
+      arrayHasValueExpr.isDefined && arrayHasNullExpr.isDefined &&
+      nullKeyLiteralProto.isDefined) {
+      // Create boolean literals
+      val trueLiteralProto = exprToProto(Literal(true, BooleanType), Seq.empty)
+      val falseLiteralProto = exprToProto(Literal(false, BooleanType), Seq.empty)
+      val nullBooleanLiteralProto = exprToProto(Literal(null, BooleanType), Seq.empty)
+
+      if (trueLiteralProto.isDefined && falseLiteralProto.isDefined &&
+        nullBooleanLiteralProto.isDefined) {
+        // If array_has(array, key) is false, check if array has nulls
+        // If array_has(array, null) is true -> return null, else return false
+        val whenNotFoundCheckNulls = ExprOuterClass.CaseWhen
+          .newBuilder()
+          .addWhen(arrayHasNullExpr.get) // if array has nulls
+          .addThen(nullBooleanLiteralProto.get) // return null (indeterminate)
+          .setElseExpr(falseLiteralProto.get) // else return false
+          .build()
+
+        // If array_has(array, key) is true, return true, else check null case
+        val whenValueFound = ExprOuterClass.CaseWhen
+          .newBuilder()
+          .addWhen(arrayHasValueExpr.get) // if value found
+          .addThen(trueLiteralProto.get) // return true
+          .setElseExpr(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setCaseWhen(whenNotFoundCheckNulls)
+              .build()
+          ) // else check null case
+          .build()
+
+        // Check if key is null -> return null, else use the logic above
+        val whenKeyNotNull = ExprOuterClass.CaseWhen
+          .newBuilder()
+          .addWhen(isKeyNotNullExpr.get) // if key is not null
+          .addThen(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setCaseWhen(whenValueFound)
+              .build())
+          .setElseExpr(nullBooleanLiteralProto.get) // key is null -> return null
+          .build()
+
+        // Outer case: if array is null, return null, else use the logic above
+        val outerCaseWhen = ExprOuterClass.CaseWhen
+          .newBuilder()
+          .addWhen(isArrayNotNullExpr.get) // if array is not null
+          .addThen(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setCaseWhen(whenKeyNotNull)
+              .build())
+          .setElseExpr(nullBooleanLiteralProto.get) // array is null -> return null
+          .build()
+
+        Some(
+          ExprOuterClass.Expr
+            .newBuilder()
+            .setCaseWhen(outerCaseWhen)
+            .build())
+      } else {
+        withInfo(expr, expr.children: _*)
+        None
+      }
+    } else {
+      withInfo(expr, expr.children: _*)
+      None
+    }
   }
 }
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index cf49117364..de1866a322 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -325,6 +325,48 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     }
   }
 
+  test("array_contains - three-valued null logic") {
+    // Test Spark's three-valued logic for array_contains:
+    // 1. Returns true if value is found
+    // 2. Returns false if no match found AND no null elements exist
+    // 3. Returns null if no match found BUT null elements exist (indeterminate)
+    // 4. Returns null if search value is null
+    withTempDir { dir =>
+      withTempView("t1") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, n = 100)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+        // Disable constant folding to ensure Comet implementation is exercised
+        withSQLConf(
+          SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+            "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
+          // Test case 1: value found -> returns true
+          checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, 2, 3), 2) FROM t1"))
+
+          // Test case 2: no match, no nulls -> returns false
+          checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, 2, 3), 5) FROM t1"))
+
+          // Test case 3: no match, but null exists -> returns null (indeterminate)
+          checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, null, 3), 2) FROM t1"))
+
+          // Test case 4: match found even with nulls -> returns true
+          checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, null, 3), 1) FROM t1"))
+
+          // Test case 5: search value is null -> returns null
+          checkSparkAnswerAndOperator(
+            sql("SELECT array_contains(array(1, 2, 3), cast(null as int)) FROM t1"))
+
+          // Test case 6: array with nulls, searching for existing value -> returns true
+          checkSparkAnswerAndOperator(sql("SELECT array_contains(array(1, null, 3), 3) FROM t1"))
+
+          // Test case 7: empty array -> returns false
+          checkSparkAnswerAndOperator(sql("SELECT array_contains(array(), 1) FROM t1"))
+        }
+      }
+    }
+  }
+
   test("array_contains - test all types (convert from Parquet)") {
     withTempDir { dir =>
       val path = new Path(dir.toURI.toString, "test.parquet")
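For reference, the nested CaseWhen tree built above encodes the same three-valued semantics as the sketch below. This is not part of the patch; it is a minimal, self-contained Scala illustration in which the object name ArrayContainsSemantics and the Option-based modelling of SQL NULL are purely hypothetical.

// Minimal sketch (assumption: Option models SQL NULL; it is not how Comet represents values).
object ArrayContainsSemantics {
  // arr = None models a null array, an inner None models a null element,
  // key = None models a null search value.
  def arrayContains[T](arr: Option[Seq[Option[T]]], key: Option[T]): Option[Boolean] =
    (arr, key) match {
      case (None, _) => None // null array -> null
      case (_, None) => None // null key -> null
      case (Some(elems), Some(k)) =>
        if (elems.contains(Some(k))) Some(true) // value found -> true
        else if (elems.contains(None)) None // not found but nulls present -> indeterminate
        else Some(false) // not found, no nulls -> false
    }
}

For example, arrayContains(Some(Seq(Some(1), None, Some(3))), Some(2)) yields None, mirroring test case 3 in the suite above, while arrayContains(Some(Seq(Some(1), None, Some(3))), Some(1)) yields Some(true), mirroring test case 4.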