diff --git a/native/spark-expr/src/datetime_funcs/date_diff.rs b/native/spark-expr/src/datetime_funcs/date_diff.rs
index 6a593f0f87..c39d2095e0 100644
--- a/native/spark-expr/src/datetime_funcs/date_diff.rs
+++ b/native/spark-expr/src/datetime_funcs/date_diff.rs
@@ -71,9 +71,22 @@ impl ScalarUDFImpl for SparkDateDiff {
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
         let [end_date, start_date] = take_function_args(self.name(), args.args)?;
 
-        // Convert scalars to arrays for uniform processing
-        let end_arr = end_date.into_array(1)?;
-        let start_arr = start_date.into_array(1)?;
+        // Determine the target length (broadcast scalars to the column length)
+        let len = match (&end_date, &start_date) {
+            (ColumnarValue::Array(a), _) => a.len(),
+            (_, ColumnarValue::Array(a)) => a.len(),
+            _ => 1,
+        };
+
+        // Convert both arguments to arrays of the same length
+        let end_arr = end_date.into_array(len)?;
+        let start_arr = start_date.into_array(len)?;
+
+        // Normalize to plain Date32, unwrapping dictionary-encoded arrays (important for Iceberg)
+        let end_arr = arrow::compute::cast(&end_arr, &DataType::Date32)
+            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+        let start_arr = arrow::compute::cast(&start_arr, &DataType::Date32)
+            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
 
         let end_date_array = end_arr
             .as_any()
@@ -97,8 +110,4 @@ impl ScalarUDFImpl for SparkDateDiff {
 
         Ok(ColumnarValue::Array(Arc::new(result)))
     }
-
-    fn aliases(&self) -> &[String] {
-        &self.aliases
-    }
 }
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetDatetimeRebaseSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetDatetimeRebaseSuite.scala
index bdb4a9d4b1..50ed6844ce 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetDatetimeRebaseSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetDatetimeRebaseSuite.scala
@@ -116,6 +116,33 @@ abstract class ParquetDatetimeRebaseSuite extends CometTestBase {
     }
   }
 
+  test("datediff works with dictionary-encoded timestamp columns") {
+    withTempPath { path =>
+      withSQLConf(
+        CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_COMET,
+        CometConf.COMET_ENABLED.key -> "true",
+        "spark.sql.parquet.enableDictionary" -> "true") {
+        val df = spark
+          .createDataFrame(
+            Seq(
+              ("a", java.sql.Timestamp.valueOf("2024-01-02 10:00:00")),
+              ("b", java.sql.Timestamp.valueOf("2024-01-03 11:00:00"))))
+          .toDF("id", "ts")
+
+        df.write.mode("overwrite").parquet(path.getAbsolutePath)
+
+        val readDf = spark.read.parquet(path.getAbsolutePath)
+
+        val result = readDf
+          .selectExpr("datediff(current_date(), ts) as diff")
+          .collect()
+
+        // Only verify execution succeeds (no CometNativeException); values depend on current_date()
+        assert(result.length == 2)
+      }
+    }
+  }
+
   private def checkSparkNoRebaseAnswer(df: => DataFrame): Unit = {
     var expected: Array[Row] = Array.empty