diff --git a/native/spark-expr/src/datetime_funcs/date_diff.rs b/native/spark-expr/src/datetime_funcs/date_diff.rs
index 6a593f0f87..11efdfdce1 100644
--- a/native/spark-expr/src/datetime_funcs/date_diff.rs
+++ b/native/spark-expr/src/datetime_funcs/date_diff.rs
@@ -71,9 +71,22 @@ impl ScalarUDFImpl for SparkDateDiff {
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
         let [end_date, start_date] = take_function_args(self.name(), args.args)?;
 
-        // Convert scalars to arrays for uniform processing
-        let end_arr = end_date.into_array(1)?;
-        let start_arr = start_date.into_array(1)?;
+        // Determine target length (broadcast scalars to column length)
+        let len = match (&end_date, &start_date) {
+            (ColumnarValue::Array(a), _) => a.len(),
+            (_, ColumnarValue::Array(a)) => a.len(),
+            _ => 1,
+        };
+
+        // Convert both arguments to arrays of the same length
+        let end_arr = end_date.into_array(len)?;
+        let start_arr = start_date.into_array(len)?;
+
+        // Normalize dictionary-backed arrays (important for Parquet / Iceberg)
+        let end_arr = arrow::compute::cast(&end_arr, &DataType::Date32)
+            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+        let start_arr = arrow::compute::cast(&start_arr, &DataType::Date32)
+            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
 
         let end_date_array = end_arr
            .as_any()
diff --git a/native/spark-expr/src/datetime_funcs/extract_date_part.rs b/native/spark-expr/src/datetime_funcs/extract_date_part.rs
index acb7d2266e..0f93821d39 100644
--- a/native/spark-expr/src/datetime_funcs/extract_date_part.rs
+++ b/native/spark-expr/src/datetime_funcs/extract_date_part.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::utils::array_with_timezone;
 use arrow::compute::{date_part, DatePart};
 use arrow::datatypes::{DataType, TimeUnit::Microsecond};
 use datafusion::common::{internal_datafusion_err, DataFusionError};
@@ -24,6 +23,8 @@ use datafusion::logical_expr::{
 };
 use std::{any::Any, fmt::Debug};
 
+use crate::utils::array_with_timezone;
+
 macro_rules! extract_date_part {
     ($struct_name:ident, $fn_name:expr, $date_part_variant:ident) => {
         #[derive(Debug, PartialEq, Eq, Hash)]
@@ -75,14 +76,28 @@ macro_rules! extract_date_part {
         match args {
             [ColumnarValue::Array(array)] => {
-                let array = array_with_timezone(
-                    array,
-                    self.timezone.clone(),
-                    Some(&DataType::Timestamp(
-                        Microsecond,
-                        Some(self.timezone.clone().into()),
-                    )),
-                )?;
+                let array = match array.data_type() {
+                    // TimestampNTZ → DO NOT apply timezone conversion
+                    DataType::Timestamp(_, None) => array.clone(),
+
+                    // Timestamp with timezone → convert from UTC to session timezone
+                    DataType::Timestamp(_, Some(_)) => array_with_timezone(
+                        array,
+                        self.timezone.clone(),
+                        Some(&DataType::Timestamp(
+                            Microsecond,
+                            Some(self.timezone.clone().into()),
+                        )),
+                    )?,
+
+                    other => {
+                        return Err(DataFusionError::Execution(format!(
+                            "extract_date_part expects a Timestamp input, got {:?}",
+                            other
+                        )));
+                    }
+                };
+
                 let result = date_part(&array, DatePart::$date_part_variant)?;
                 Ok(ColumnarValue::Array(result))
             }
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetDatetimeRebaseSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetDatetimeRebaseSuite.scala
index bdb4a9d4b1..33918b424b 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/ParquetDatetimeRebaseSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/ParquetDatetimeRebaseSuite.scala
@@ -116,6 +116,32 @@ abstract class ParquetDatetimeRebaseSuite extends CometTestBase {
     }
   }
 
+  test("datediff works with dictionary-encoded timestamp columns") {
+    withSQLConf(
+      "spark.sql.parquet.enableDictionary" -> "true",
+      CometConf.COMET_ENABLED.key -> "true") {
+      val df = spark
+        .createDataFrame(
+          Seq(
+            ("a", java.sql.Timestamp.valueOf("2024-01-02 10:00:00")),
+            ("b", java.sql.Timestamp.valueOf("2024-01-03 11:00:00"))))
+        .toDF("id", "ts")
+
+      withTempPath { path =>
+        df.write.mode("overwrite").parquet(path.getAbsolutePath)
+
+        val readDf = spark.read.parquet(path.getAbsolutePath)
+
+        // This used to fail due to an array length mismatch
+        val result = readDf
+          .selectExpr("datediff(current_date(), ts)")
+          .collect()
+
+        assert(result.length == 2)
+      }
+    }
+  }
+
   private def checkSparkNoRebaseAnswer(df: => DataFrame): Unit = {
     var expected: Array[Row] = Array.empty