diff --git a/native/core/src/execution/operators/iceberg_scan.rs b/native/core/src/execution/operators/iceberg_scan.rs index 2f639e9f70..bc20592e90 100644 --- a/native/core/src/execution/operators/iceberg_scan.rs +++ b/native/core/src/execution/operators/iceberg_scan.rs @@ -44,6 +44,7 @@ use crate::parquet::parquet_support::SparkParquetOptions; use crate::parquet::schema_adapter::SparkSchemaAdapterFactory; use datafusion::datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper}; use datafusion_comet_spark_expr::EvalMode; +use iceberg::scan::FileScanTask; /// Iceberg table scan operator that uses iceberg-rust to read Iceberg tables. /// @@ -58,8 +59,8 @@ pub struct IcebergScanExec { plan_properties: PlanProperties, /// Catalog-specific configuration for FileIO catalog_properties: HashMap, - /// Pre-planned file scan tasks, grouped by partition - file_task_groups: Vec>, + /// Pre-planned file scan tasks + tasks: Vec, /// Metrics metrics: ExecutionPlanMetricsSet, } @@ -69,11 +70,10 @@ impl IcebergScanExec { metadata_location: String, schema: SchemaRef, catalog_properties: HashMap, - file_task_groups: Vec>, + tasks: Vec, ) -> Result { let output_schema = schema; - let num_partitions = file_task_groups.len(); - let plan_properties = Self::compute_properties(Arc::clone(&output_schema), num_partitions); + let plan_properties = Self::compute_properties(Arc::clone(&output_schema), 1); let metrics = ExecutionPlanMetricsSet::new(); @@ -82,7 +82,7 @@ impl IcebergScanExec { output_schema, plan_properties, catalog_properties, - file_task_groups, + tasks, metrics, }) } @@ -127,19 +127,10 @@ impl ExecutionPlan for IcebergScanExec { fn execute( &self, - partition: usize, + _partition: usize, context: Arc, ) -> DFResult { - if partition < self.file_task_groups.len() { - let tasks = &self.file_task_groups[partition]; - self.execute_with_tasks(tasks.clone(), partition, context) - } else { - Err(DataFusionError::Execution(format!( - "IcebergScanExec: Partition index {} out of range (only {} task groups available)", - partition, - self.file_task_groups.len() - ))) - } + self.execute_with_tasks(self.tasks.clone(), context) } fn metrics(&self) -> Option { @@ -152,15 +143,14 @@ impl IcebergScanExec { /// deletes via iceberg-rust's ArrowReader. 
fn execute_with_tasks( &self, - tasks: Vec, - partition: usize, + tasks: Vec, context: Arc, ) -> DFResult { let output_schema = Arc::clone(&self.output_schema); let file_io = Self::load_file_io(&self.catalog_properties, &self.metadata_location)?; let batch_size = context.session_config().batch_size(); - let metrics = IcebergScanMetrics::new(&self.metrics, partition); + let metrics = IcebergScanMetrics::new(&self.metrics); let num_tasks = tasks.len(); metrics.num_splits.add(num_tasks); @@ -221,10 +211,10 @@ struct IcebergScanMetrics { } impl IcebergScanMetrics { - fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + fn new(metrics: &ExecutionPlanMetricsSet) -> Self { Self { - baseline: BaselineMetrics::new(metrics, partition), - num_splits: MetricBuilder::new(metrics).counter("num_splits", partition), + baseline: BaselineMetrics::new(metrics, 0), + num_splits: MetricBuilder::new(metrics).counter("num_splits", 0), } } } @@ -311,11 +301,11 @@ where impl DisplayAs for IcebergScanExec { fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { - let num_tasks: usize = self.file_task_groups.iter().map(|g| g.len()).sum(); write!( f, "IcebergScanExec: metadata_location={}, num_tasks={}", - self.metadata_location, num_tasks + self.metadata_location, + self.tasks.len() ) } } diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 44ff20a44f..12db052394 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1132,33 +1132,28 @@ impl PhysicalPlanner { )) } OpStruct::IcebergScan(scan) => { - let required_schema: SchemaRef = - convert_spark_types_to_arrow_schema(scan.required_schema.as_slice()); + // Extract common data and single partition's file tasks + // Per-partition injection happens in Scala before sending to native + let common = scan + .common + .as_ref() + .ok_or_else(|| GeneralError("IcebergScan missing common data".into()))?; - let catalog_properties: HashMap = scan + let required_schema = + convert_spark_types_to_arrow_schema(common.required_schema.as_slice()); + let catalog_properties: HashMap = common .catalog_properties .iter() .map(|(k, v)| (k.clone(), v.clone())) .collect(); - - let metadata_location = scan.metadata_location.clone(); - - debug_assert!( - !scan.file_partitions.is_empty(), - "IcebergScan must have at least one file partition. This indicates a bug in Scala serialization." - ); - - let tasks = parse_file_scan_tasks( - scan, - &scan.file_partitions[self.partition as usize].file_scan_tasks, - )?; - let file_task_groups = vec![tasks]; + let metadata_location = common.metadata_location.clone(); + let tasks = parse_file_scan_tasks_from_common(common, &scan.file_scan_tasks)?; let iceberg_scan = IcebergScanExec::new( metadata_location, required_schema, catalog_properties, - file_task_groups, + tasks, )?; Ok(( @@ -2743,15 +2738,14 @@ fn partition_data_to_struct( /// Each task contains a residual predicate that is used for row-group level filtering /// during Parquet scanning. /// -/// This function uses deduplication pools from the IcebergScan to avoid redundant parsing -/// of schemas, partition specs, partition types, name mappings, and other repeated data. -fn parse_file_scan_tasks( - proto_scan: &spark_operator::IcebergScan, +/// This function uses deduplication pools from the IcebergScanCommon to avoid redundant +/// parsing of schemas, partition specs, partition types, name mappings, and other repeated data. 
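For reference, the pools that parse_file_scan_tasks_from_common indexes into are built on the Scala side with a value-to-index map, so each distinct JSON string is serialized exactly once no matter how many tasks reference it. A minimal sketch of that deduplication pattern (standalone and simplified; the real code applies `getOrElseUpdate` directly against the protobuf builders):

```scala
import scala.collection.mutable

// Simplified stand-in for the schema/spec/name-mapping pools: intern a JSON string
// and hand back the pool index that tasks carry instead of the full payload.
final class JsonPool {
  private val indexByJson = mutable.HashMap[String, Int]()
  private val entries = mutable.ArrayBuffer[String]()

  def intern(json: String): Int =
    indexByJson.getOrElseUpdate(json, {
      val idx = entries.size
      entries += json // appended once per distinct value
      idx
    })

  def all: Seq[String] = entries.toSeq
}
```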
+fn parse_file_scan_tasks_from_common( + proto_common: &spark_operator::IcebergScanCommon, proto_tasks: &[spark_operator::IcebergFileScanTask], ) -> Result, ExecutionError> { - // Build caches upfront: for 10K tasks with 1 schema, this parses the schema - // once instead of 10K times, eliminating redundant JSON deserialization - let schema_cache: Vec> = proto_scan + // Parse each unique schema once, not once per task + let schema_cache: Vec> = proto_common .schema_pool .iter() .map(|json| { @@ -2764,7 +2758,7 @@ fn parse_file_scan_tasks( }) .collect::, _>>()?; - let partition_spec_cache: Vec>> = proto_scan + let partition_spec_cache: Vec>> = proto_common .partition_spec_pool .iter() .map(|json| { @@ -2774,7 +2768,7 @@ fn parse_file_scan_tasks( }) .collect(); - let name_mapping_cache: Vec>> = proto_scan + let name_mapping_cache: Vec>> = proto_common .name_mapping_pool .iter() .map(|json| { @@ -2784,7 +2778,7 @@ fn parse_file_scan_tasks( }) .collect(); - let delete_files_cache: Vec> = proto_scan + let delete_files_cache: Vec> = proto_common .delete_files_pool .iter() .map(|list| { @@ -2796,7 +2790,7 @@ fn parse_file_scan_tasks( "EQUALITY_DELETES" => iceberg::spec::DataContentType::EqualityDeletes, other => { return Err(GeneralError(format!( - "Invalid delete content type '{}'. This indicates a bug in Scala serialization.", + "Invalid delete content type '{}'", other ))) } @@ -2817,7 +2811,6 @@ fn parse_file_scan_tasks( }) .collect::, _>>()?; - // Partition data pool is in protobuf messages let results: Result, _> = proto_tasks .iter() .map(|proto_task| { @@ -2851,7 +2844,7 @@ fn parse_file_scan_tasks( }; let bound_predicate = if let Some(idx) = proto_task.residual_idx { - proto_scan + proto_common .residual_pool .get(idx as usize) .and_then(convert_spark_expr_to_predicate) @@ -2871,24 +2864,22 @@ fn parse_file_scan_tasks( }; let partition = if let Some(partition_data_idx) = proto_task.partition_data_idx { - // Get partition data from protobuf pool - let partition_data_proto = proto_scan + let partition_data_proto = proto_common .partition_data_pool .get(partition_data_idx as usize) .ok_or_else(|| { ExecutionError::GeneralError(format!( "Invalid partition_data_idx: {} (pool size: {})", partition_data_idx, - proto_scan.partition_data_pool.len() + proto_common.partition_data_pool.len() )) })?; - // Convert protobuf PartitionData to iceberg Struct match partition_data_to_struct(partition_data_proto) { Ok(s) => Some(s), Err(e) => { return Err(ExecutionError::GeneralError(format!( - "Failed to deserialize partition data from protobuf: {}", + "Failed to deserialize partition data: {}", e ))) } @@ -2907,14 +2898,14 @@ fn parse_file_scan_tasks( .and_then(|idx| name_mapping_cache.get(idx as usize)) .and_then(|opt| opt.clone()); - let project_field_ids = proto_scan + let project_field_ids = proto_common .project_field_ids_pool .get(proto_task.project_field_ids_idx as usize) .ok_or_else(|| { ExecutionError::GeneralError(format!( "Invalid project_field_ids_idx: {} (pool size: {})", proto_task.project_field_ids_idx, - proto_scan.project_field_ids_pool.len() + proto_common.project_field_ids_pool.len() )) })? .field_ids diff --git a/native/proto/src/lib.rs b/native/proto/src/lib.rs index 6dfe546ac8..a55657b7af 100644 --- a/native/proto/src/lib.rs +++ b/native/proto/src/lib.rs @@ -34,6 +34,7 @@ pub mod spark_partitioning { // Include generated modules from .proto files. 
#[allow(missing_docs)] +#[allow(clippy::large_enum_variant)] pub mod spark_operator { include!(concat!("generated", "/spark.spark_operator.rs")); } diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 73c087cf36..78f118e6db 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -156,28 +156,34 @@ message PartitionData { repeated PartitionValue values = 1; } -message IcebergScan { - // Schema to read - repeated SparkStructField required_schema = 1; - +// Common data shared by all partitions in split mode (sent once, captured in closure) +message IcebergScanCommon { // Catalog-specific configuration for FileIO (credentials, S3/GCS config, etc.) - map catalog_properties = 2; - - // Pre-planned file scan tasks grouped by Spark partition - repeated IcebergFilePartition file_partitions = 3; + map catalog_properties = 1; // Table metadata file path for FileIO initialization - string metadata_location = 4; + string metadata_location = 2; + + // Schema to read + repeated SparkStructField required_schema = 3; - // Deduplication pools - shared data referenced by index from tasks - repeated string schema_pool = 5; - repeated string partition_type_pool = 6; - repeated string partition_spec_pool = 7; - repeated string name_mapping_pool = 8; - repeated ProjectFieldIdList project_field_ids_pool = 9; - repeated PartitionData partition_data_pool = 10; - repeated DeleteFileList delete_files_pool = 11; - repeated spark.spark_expression.Expr residual_pool = 12; + // Deduplication pools (must contain all entries for cross-partition deduplication) + repeated string schema_pool = 4; + repeated string partition_type_pool = 5; + repeated string partition_spec_pool = 6; + repeated string name_mapping_pool = 7; + repeated ProjectFieldIdList project_field_ids_pool = 8; + repeated PartitionData partition_data_pool = 9; + repeated DeleteFileList delete_files_pool = 10; + repeated spark.spark_expression.Expr residual_pool = 11; +} + +message IcebergScan { + // Common data shared across partitions (pools, metadata, catalog props) + IcebergScanCommon common = 1; + + // Single partition's file scan tasks + repeated IcebergFileScanTask file_scan_tasks = 2; } // Helper message for deduplicating field ID lists @@ -190,11 +196,6 @@ message DeleteFileList { repeated IcebergDeleteFile delete_files = 1; } -// Groups FileScanTasks for a single Spark partition -message IcebergFilePartition { - repeated IcebergFileScanTask file_scan_tasks = 1; -} - // Iceberg FileScanTask containing data file, delete files, and residual filter message IcebergFileScanTask { // Data file path (e.g., s3://bucket/warehouse/db/table/data/00000-0-abc.parquet) diff --git a/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala b/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala index 2d772063e4..c5b6554054 100644 --- a/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala +++ b/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala @@ -734,7 +734,7 @@ case class CometIcebergNativeScanMetadata( table: Any, metadataLocation: String, nameMapping: Option[String], - tasks: java.util.List[_], + @transient tasks: java.util.List[_], scanSchema: Any, tableSchema: Any, globalFieldIdMapping: Map[String, Int], diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala index 68a63b6ae8..a3b3208b02 100644 --- 
a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala @@ -28,12 +28,13 @@ import scala.jdk.CollectionConverters._ import org.apache.hadoop.conf.Configuration import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, PlanExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, DynamicPruningExpression, Expression, GenericInternalRow, PlanExpression} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{sideBySide, ArrayBasedMapData, GenericArrayData, MetadataColumnHelper} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues import org.apache.spark.sql.comet.{CometBatchScanExec, CometScanExec} -import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} +import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan, SubqueryAdaptiveBroadcastExec} +import org.apache.spark.sql.execution.InSubqueryExec import org.apache.spark.sql.execution.datasources.HadoopFsRelation import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan @@ -51,11 +52,15 @@ import org.apache.comet.parquet.{CometParquetScan, Native, SupportsComet} import org.apache.comet.parquet.CometParquetUtils.{encryptionEnabled, isEncryptionConfigSupported} import org.apache.comet.serde.operator.CometNativeScan import org.apache.comet.shims.CometTypeShim +import org.apache.comet.shims.ShimSubqueryBroadcast /** * Spark physical optimizer rule for replacing Spark scans with Comet scans. */ -case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with CometTypeShim { +case class CometScanRule(session: SparkSession) + extends Rule[SparkPlan] + with CometTypeShim + with ShimSubqueryBroadcast { import CometScanRule._ @@ -327,10 +332,6 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com case _ if scanExec.scan.getClass.getName == "org.apache.iceberg.spark.source.SparkBatchQueryScan" => - if (scanExec.runtimeFilters.exists(isDynamicPruningFilter)) { - return withInfo(scanExec, "Dynamic Partition Pruning is not supported") - } - val fallbackReasons = new ListBuffer[String]() // Native Iceberg scan requires both configs to be enabled @@ -621,10 +622,47 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com !hasUnsupportedDeletes } + // Check that all DPP subqueries use InSubqueryExec which we know how to handle. + // Future Spark versions might introduce new subquery types we haven't tested. + val dppSubqueriesSupported = { + val unsupportedSubqueries = scanExec.runtimeFilters.collect { + case DynamicPruningExpression(e) if !e.isInstanceOf[InSubqueryExec] => + e.getClass.getSimpleName + } + // Check for multi-index DPP which we don't support yet. + // SPARK-46946 changed SubqueryAdaptiveBroadcastExec from index: Int to indices: Seq[Int] + // as a preparatory refactor for future features (Null Safe Equality DPP, multiple + // equality predicates). Currently indices always has one element, but future Spark + // versions might use multiple indices. 
+ val multiIndexDpp = scanExec.runtimeFilters.exists { + case DynamicPruningExpression(e: InSubqueryExec) => + e.plan match { + case sab: SubqueryAdaptiveBroadcastExec => + getSubqueryBroadcastIndices(sab).length > 1 + case _ => false + } + case _ => false + } + if (unsupportedSubqueries.nonEmpty) { + fallbackReasons += + s"Unsupported DPP subquery types: ${unsupportedSubqueries.mkString(", ")}. " + + "CometIcebergNativeScanExec only supports InSubqueryExec for DPP" + false + } else if (multiIndexDpp) { + // See SPARK-46946 for context on multi-index DPP + fallbackReasons += + "Multi-index DPP (indices.length > 1) is not yet supported. " + + "See SPARK-46946 for context." + false + } else { + true + } + } + if (schemaSupported && fileIOCompatible && formatVersionSupported && allParquetFiles && allSupportedFilesystems && partitionTypesSupported && complexTypePredicatesSupported && transformFunctionsSupported && - deleteFileTypesSupported) { + deleteFileTypesSupported && dppSubqueriesSupported) { CometBatchScanExec( scanExec.clone().asInstanceOf[BatchScanExec], runtimeFilters = scanExec.runtimeFilters, diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala index 0ad82af8f8..957f621032 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometIcebergNativeScan.scala @@ -28,10 +28,11 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.comet.{CometBatchScanExec, CometNativeExec} +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceRDD, DataSourceRDDPartition} import org.apache.spark.sql.types._ import org.apache.comet.ConfigEntry -import org.apache.comet.iceberg.IcebergReflection +import org.apache.comet.iceberg.{CometIcebergNativeScanMetadata, IcebergReflection} import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass} import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.OperatorOuterClass.{Operator, SparkStructField} @@ -309,7 +310,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit contentScanTaskClass: Class[_], fileScanTaskClass: Class[_], taskBuilder: OperatorOuterClass.IcebergFileScanTask.Builder, - icebergScanBuilder: OperatorOuterClass.IcebergScan.Builder, + commonBuilder: OperatorOuterClass.IcebergScanCommon.Builder, partitionTypeToPoolIndex: mutable.HashMap[String, Int], partitionSpecToPoolIndex: mutable.HashMap[String, Int], partitionDataToPoolIndex: mutable.HashMap[String, Int]): Unit = { @@ -334,7 +335,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val specIdx = partitionSpecToPoolIndex.getOrElseUpdate( partitionSpecJson, { val idx = partitionSpecToPoolIndex.size - icebergScanBuilder.addPartitionSpecPool(partitionSpecJson) + commonBuilder.addPartitionSpecPool(partitionSpecJson) idx }) taskBuilder.setPartitionSpecIdx(specIdx) @@ -415,7 +416,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val typeIdx = partitionTypeToPoolIndex.getOrElseUpdate( partitionTypeJson, { val idx = partitionTypeToPoolIndex.size - icebergScanBuilder.addPartitionTypePool(partitionTypeJson) + commonBuilder.addPartitionTypePool(partitionTypeJson) idx }) taskBuilder.setPartitionTypeIdx(typeIdx) @@ -470,7 
+471,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val partitionDataIdx = partitionDataToPoolIndex.getOrElseUpdate( partitionDataKey, { val idx = partitionDataToPoolIndex.size - icebergScanBuilder.addPartitionDataPool(partitionDataProto) + commonBuilder.addPartitionDataPool(partitionDataProto) idx }) taskBuilder.setPartitionDataIdx(partitionDataIdx) @@ -671,17 +672,59 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit } /** - * Serializes a CometBatchScanExec wrapping an Iceberg SparkBatchQueryScan to protobuf. + * Converts a CometBatchScanExec to a minimal placeholder IcebergScan operator. * - * Uses pre-extracted metadata from CometScanRule to avoid redundant reflection operations. All - * reflection and validation was done during planning, so serialization failures here would - * indicate a programming error rather than an expected fallback condition. + * Returns a placeholder operator with only metadata_location for matching during partition + * injection. All other fields (catalog properties, required schema, pools, partition data) are + * set by serializePartitions() at execution time after DPP resolves. */ override def convert( scan: CometBatchScanExec, builder: Operator.Builder, childOp: Operator*): Option[OperatorOuterClass.Operator] = { + + val metadata = scan.nativeIcebergScanMetadata.getOrElse { + throw new IllegalStateException( + "Programming error: CometBatchScanExec.nativeIcebergScanMetadata is None. " + + "Metadata should have been extracted in CometScanRule.") + } + val icebergScanBuilder = OperatorOuterClass.IcebergScan.newBuilder() + val commonBuilder = OperatorOuterClass.IcebergScanCommon.newBuilder() + + // Only set metadata_location - used for matching in PlanDataInjector. + // All other fields (catalog_properties, required_schema, pools) are set by + // serializePartitions() at execution time, so setting them here would be wasted work. + commonBuilder.setMetadataLocation(metadata.metadataLocation) + + icebergScanBuilder.setCommon(commonBuilder.build()) + // partition field intentionally empty - will be populated at execution time + + builder.clearChildren() + Some(builder.setIcebergScan(icebergScanBuilder).build()) + } + + /** + * Serializes partitions from inputRDD at execution time. + * + * Called after doPrepare() has resolved DPP subqueries. Builds pools and per-partition data in + * one pass from the DPP-filtered partitions. 
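+ *
+ * A rough shape of the result (sketch; `i` indexes the DPP-filtered Spark partitions):
+ * {{{
+ *   val (commonBytes, perPartitionBytes) = serializePartitions(scanExec, output, metadata)
+ *   // commonBytes          : serialized IcebergScanCommon (pools, catalog props, schema)
+ *   // perPartitionBytes(i) : serialized IcebergScan carrying partition i's file_scan_tasks
+ * }}}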
+ * + * @param scanExec + * The BatchScanExec whose inputRDD contains the DPP-filtered partitions + * @param output + * The output attributes for the scan + * @param metadata + * Pre-extracted Iceberg metadata from CometScanRule + * @return + * Tuple of (commonBytes, perPartitionBytes) for native execution + */ + def serializePartitions( + scanExec: BatchScanExec, + output: Seq[Attribute], + metadata: CometIcebergNativeScanMetadata): (Array[Byte], Array[Array[Byte]]) = { + + val commonBuilder = OperatorOuterClass.IcebergScanCommon.newBuilder() // Deduplication structures - map unique values to pool indices val schemaToPoolIndex = mutable.HashMap[AnyRef, Int]() @@ -689,300 +732,225 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val partitionSpecToPoolIndex = mutable.HashMap[String, Int]() val nameMappingToPoolIndex = mutable.HashMap[String, Int]() val projectFieldIdsToPoolIndex = mutable.HashMap[Seq[Int], Int]() - val partitionDataToPoolIndex = mutable.HashMap[String, Int]() // Base64 bytes -> pool index + val partitionDataToPoolIndex = mutable.HashMap[String, Int]() val deleteFilesToPoolIndex = mutable.HashMap[Seq[OperatorOuterClass.IcebergDeleteFile], Int]() val residualToPoolIndex = mutable.HashMap[Option[Expr], Int]() - var totalTasks = 0 + val perPartitionBuilders = mutable.ArrayBuffer[OperatorOuterClass.IcebergScan]() - // Get pre-extracted metadata from planning phase - // If metadata is None, this is a programming error - metadata should have been extracted - // in CometScanRule before creating CometBatchScanExec - val metadata = scan.nativeIcebergScanMetadata.getOrElse { - throw new IllegalStateException( - "Programming error: CometBatchScanExec.nativeIcebergScanMetadata is None. " + - "Metadata should have been extracted in CometScanRule.") - } - - // Use pre-extracted metadata (no reflection needed) - icebergScanBuilder.setMetadataLocation(metadata.metadataLocation) + var totalTasks = 0 + commonBuilder.setMetadataLocation(metadata.metadataLocation) metadata.catalogProperties.foreach { case (key, value) => - icebergScanBuilder.putCatalogProperties(key, value) + commonBuilder.putCatalogProperties(key, value) } - // Set required_schema from output - scan.output.foreach { attr => + output.foreach { attr => val field = SparkStructField .newBuilder() .setName(attr.name) .setNullable(attr.nullable) serializeDataType(attr.dataType).foreach(field.setDataType) - icebergScanBuilder.addRequiredSchema(field.build()) + commonBuilder.addRequiredSchema(field.build()) } - // Extract FileScanTasks from the InputPartitions in the RDD - try { - scan.wrapped.inputRDD match { - case rdd: org.apache.spark.sql.execution.datasources.v2.DataSourceRDD => - val partitions = rdd.partitions - partitions.foreach { partition => - val partitionBuilder = OperatorOuterClass.IcebergFilePartition.newBuilder() + // Load Iceberg classes once (avoid repeated class loading in loop) + // scalastyle:off classforname + val contentScanTaskClass = Class.forName(IcebergReflection.ClassNames.CONTENT_SCAN_TASK) + val fileScanTaskClass = Class.forName(IcebergReflection.ClassNames.FILE_SCAN_TASK) + val contentFileClass = Class.forName(IcebergReflection.ClassNames.CONTENT_FILE) + val schemaParserClass = Class.forName(IcebergReflection.ClassNames.SCHEMA_PARSER) + val schemaClass = Class.forName(IcebergReflection.ClassNames.SCHEMA) + // scalastyle:on classforname - val inputPartitions = partition - .asInstanceOf[org.apache.spark.sql.execution.datasources.v2.DataSourceRDDPartition] - 
.inputPartitions + // Cache method lookups (avoid repeated getMethod in loop) + val fileMethod = contentScanTaskClass.getMethod("file") + val startMethod = contentScanTaskClass.getMethod("start") + val lengthMethod = contentScanTaskClass.getMethod("length") + val residualMethod = contentScanTaskClass.getMethod("residual") + val taskSchemaMethod = fileScanTaskClass.getMethod("schema") + val toJsonMethod = schemaParserClass.getMethod("toJson", schemaClass) + toJsonMethod.setAccessible(true) + + // Access inputRDD - safe now, DPP is resolved + scanExec.inputRDD match { + case rdd: DataSourceRDD => + val partitions = rdd.partitions + partitions.foreach { partition => + val partitionBuilder = OperatorOuterClass.IcebergScan.newBuilder() + + val inputPartitions = partition + .asInstanceOf[DataSourceRDDPartition] + .inputPartitions + + inputPartitions.foreach { inputPartition => + val inputPartClass = inputPartition.getClass - inputPartitions.foreach { inputPartition => - val inputPartClass = inputPartition.getClass + try { + val taskGroupMethod = inputPartClass.getDeclaredMethod("taskGroup") + taskGroupMethod.setAccessible(true) + val taskGroup = taskGroupMethod.invoke(inputPartition) - try { - val taskGroupMethod = inputPartClass.getDeclaredMethod("taskGroup") - taskGroupMethod.setAccessible(true) - val taskGroup = taskGroupMethod.invoke(inputPartition) + val taskGroupClass = taskGroup.getClass + val tasksMethod = taskGroupClass.getMethod("tasks") + val tasksCollection = + tasksMethod.invoke(taskGroup).asInstanceOf[java.util.Collection[_]] - val taskGroupClass = taskGroup.getClass - val tasksMethod = taskGroupClass.getMethod("tasks") - val tasksCollection = - tasksMethod.invoke(taskGroup).asInstanceOf[java.util.Collection[_]] + tasksCollection.asScala.foreach { task => + totalTasks += 1 - tasksCollection.asScala.foreach { task => - totalTasks += 1 + val taskBuilder = OperatorOuterClass.IcebergFileScanTask.newBuilder() - try { - val taskBuilder = OperatorOuterClass.IcebergFileScanTask.newBuilder() - - // scalastyle:off classforname - val contentScanTaskClass = - Class.forName(IcebergReflection.ClassNames.CONTENT_SCAN_TASK) - val fileScanTaskClass = - Class.forName(IcebergReflection.ClassNames.FILE_SCAN_TASK) - val contentFileClass = - Class.forName(IcebergReflection.ClassNames.CONTENT_FILE) - // scalastyle:on classforname - - val fileMethod = contentScanTaskClass.getMethod("file") - val dataFile = fileMethod.invoke(task) - - val filePathOpt = - IcebergReflection.extractFileLocation(contentFileClass, dataFile) - - filePathOpt match { - case Some(filePath) => - taskBuilder.setDataFilePath(filePath) - case None => - val msg = - "Iceberg reflection failure: Cannot extract file path from data file" - logError(msg) - throw new RuntimeException(msg) - } + val dataFile = fileMethod.invoke(task) - val startMethod = contentScanTaskClass.getMethod("start") - val start = startMethod.invoke(task).asInstanceOf[Long] - taskBuilder.setStart(start) - - val lengthMethod = contentScanTaskClass.getMethod("length") - val length = lengthMethod.invoke(task).asInstanceOf[Long] - taskBuilder.setLength(length) - - try { - // Equality deletes require the full table schema to resolve field IDs, - // even for columns not in the projection. Schema evolution requires - // using the snapshot's schema to correctly read old data files. - // These requirements conflict, so we choose based on delete presence. 
- - val taskSchemaMethod = fileScanTaskClass.getMethod("schema") - val taskSchema = taskSchemaMethod.invoke(task) - - val deletes = - IcebergReflection.getDeleteFilesFromTask(task, fileScanTaskClass) - val hasDeletes = !deletes.isEmpty - - // Schema to pass to iceberg-rust's FileScanTask. - // This is used by RecordBatchTransformer for field type lookups (e.g., in - // constants_map) and default value generation. The actual projection is - // controlled by project_field_ids. - // - // Schema selection logic: - // 1. If hasDeletes=true: Use taskSchema (file-specific schema) because - // delete files reference specific schema versions and we need exact schema - // matching for MOR. - // 2. Else if scanSchema contains columns not in tableSchema: Use scanSchema - // because this is a VERSION AS OF query reading a historical snapshot with - // different schema (e.g., after column drop, scanSchema has old columns - // that tableSchema doesn't) - // 3. Else: Use tableSchema because scanSchema is the query OUTPUT schema - // (e.g., for aggregates like "SELECT count(*)", scanSchema only has - // aggregate fields and doesn't contain partition columns needed by - // constants_map) - val schema: AnyRef = - if (hasDeletes) { - taskSchema - } else { - // Check if scanSchema has columns that tableSchema doesn't have - // (VERSION AS OF case) - val scanSchemaFieldIds = IcebergReflection - .buildFieldIdMapping(metadata.scanSchema) - .values - .toSet - val tableSchemaFieldIds = IcebergReflection - .buildFieldIdMapping(metadata.tableSchema) - .values - .toSet - val hasHistoricalColumns = - scanSchemaFieldIds.exists(id => !tableSchemaFieldIds.contains(id)) - - if (hasHistoricalColumns) { - // VERSION AS OF: scanSchema has columns that current table doesn't have - metadata.scanSchema.asInstanceOf[AnyRef] - } else { - // Regular query: use tableSchema for partition field lookups - metadata.tableSchema.asInstanceOf[AnyRef] - } - } - - // scalastyle:off classforname - val schemaParserClass = - Class.forName(IcebergReflection.ClassNames.SCHEMA_PARSER) - val schemaClass = Class.forName(IcebergReflection.ClassNames.SCHEMA) - // scalastyle:on classforname - val toJsonMethod = schemaParserClass.getMethod("toJson", schemaClass) - toJsonMethod.setAccessible(true) - - // Use object identity for deduplication: Iceberg Schema objects are immutable - // and reused across tasks, making identity-based deduplication safe - val schemaIdx = schemaToPoolIndex.getOrElseUpdate( - schema, { - val idx = schemaToPoolIndex.size - val schemaJson = toJsonMethod.invoke(null, schema).asInstanceOf[String] - icebergScanBuilder.addSchemaPool(schemaJson) - idx - }) - taskBuilder.setSchemaIdx(schemaIdx) - - // Build field ID mapping from the schema we're using - val nameToFieldId = IcebergReflection.buildFieldIdMapping(schema) - - // Extract project_field_ids for scan.output columns. - // For schema evolution: try task schema first, then fall back to - // global scan schema (pre-extracted in metadata). 
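The decision that these removed comments describe (and that the condensed new code below keeps) can be summarized as a small function; a sketch with simplified parameters (the real code works on reflective `AnyRef` schema handles and field-id sets obtained from `buildFieldIdMapping`):

```scala
// Sketch only: which Iceberg schema gets attached to a FileScanTask.
def chooseTaskSchema(
    hasDeletes: Boolean,
    taskSchema: AnyRef,
    scanSchema: AnyRef,
    tableSchema: AnyRef,
    scanFieldIds: Set[Int],
    tableFieldIds: Set[Int]): AnyRef =
  if (hasDeletes) {
    taskSchema // MOR: delete files reference an exact schema version
  } else if (scanFieldIds.exists(id => !tableFieldIds.contains(id))) {
    scanSchema // VERSION AS OF: scan has columns the current table no longer has
  } else {
    tableSchema // regular query: keep partition fields needed by constants_map
  }
```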
- val projectFieldIds = scan.output.flatMap { attr => - nameToFieldId - .get(attr.name) - .orElse(metadata.globalFieldIdMapping.get(attr.name)) - .orElse { - logWarning( - s"Column '${attr.name}' not found in task or scan schema," + - "skipping projection") - None - } - } - - // Deduplicate project field IDs - val projectFieldIdsIdx = projectFieldIdsToPoolIndex.getOrElseUpdate( - projectFieldIds, { - val idx = projectFieldIdsToPoolIndex.size - val listBuilder = OperatorOuterClass.ProjectFieldIdList.newBuilder() - projectFieldIds.foreach(id => listBuilder.addFieldIds(id)) - icebergScanBuilder.addProjectFieldIdsPool(listBuilder.build()) - idx - }) - taskBuilder.setProjectFieldIdsIdx(projectFieldIdsIdx) - } catch { - case e: Exception => - val msg = - "Iceberg reflection failure: " + - "Failed to extract schema from FileScanTask: " + - s"${e.getMessage}" - logError(msg) - throw new RuntimeException(msg, e) - } + val filePathOpt = + IcebergReflection.extractFileLocation(contentFileClass, dataFile) - // Deduplicate delete files - val deleteFilesList = - extractDeleteFilesList(task, contentFileClass, fileScanTaskClass) - if (deleteFilesList.nonEmpty) { - val deleteFilesIdx = deleteFilesToPoolIndex.getOrElseUpdate( - deleteFilesList, { - val idx = deleteFilesToPoolIndex.size - val listBuilder = OperatorOuterClass.DeleteFileList.newBuilder() - deleteFilesList.foreach(df => listBuilder.addDeleteFiles(df)) - icebergScanBuilder.addDeleteFilesPool(listBuilder.build()) - idx - }) - taskBuilder.setDeleteFilesIdx(deleteFilesIdx) - } + filePathOpt match { + case Some(filePath) => + taskBuilder.setDataFilePath(filePath) + case None => + val msg = + "Iceberg reflection failure: Cannot extract file path from data file" + logError(msg) + throw new RuntimeException(msg) + } - // Extract and deduplicate residual expression - val residualExprOpt = - try { - val residualMethod = contentScanTaskClass.getMethod("residual") - val residualExpr = residualMethod.invoke(task) - - val catalystExpr = convertIcebergExpression(residualExpr, scan.output) - - catalystExpr.flatMap { expr => - exprToProto(expr, scan.output, binding = false) - } - } catch { - case e: Exception => - logWarning( - "Failed to extract residual expression from FileScanTask: " + - s"${e.getMessage}") - None - } - - residualExprOpt.foreach { residualExpr => - val residualIdx = residualToPoolIndex.getOrElseUpdate( - Some(residualExpr), { - val idx = residualToPoolIndex.size - icebergScanBuilder.addResidualPool(residualExpr) - idx - }) - taskBuilder.setResidualIdx(residualIdx) + val start = startMethod.invoke(task).asInstanceOf[Long] + taskBuilder.setStart(start) + + val length = lengthMethod.invoke(task).asInstanceOf[Long] + taskBuilder.setLength(length) + + val taskSchema = taskSchemaMethod.invoke(task) + + val deletes = + IcebergReflection.getDeleteFilesFromTask(task, fileScanTaskClass) + val hasDeletes = !deletes.isEmpty + + val schema: AnyRef = + if (hasDeletes) { + taskSchema + } else { + val scanSchemaFieldIds = IcebergReflection + .buildFieldIdMapping(metadata.scanSchema) + .values + .toSet + val tableSchemaFieldIds = IcebergReflection + .buildFieldIdMapping(metadata.tableSchema) + .values + .toSet + val hasHistoricalColumns = + scanSchemaFieldIds.exists(id => !tableSchemaFieldIds.contains(id)) + + if (hasHistoricalColumns) { + metadata.scanSchema.asInstanceOf[AnyRef] + } else { + metadata.tableSchema.asInstanceOf[AnyRef] } + } - // Serialize partition spec and data (field definitions, transforms, values) - serializePartitionData( - task, - 
contentScanTaskClass, - fileScanTaskClass, - taskBuilder, - icebergScanBuilder, - partitionTypeToPoolIndex, - partitionSpecToPoolIndex, - partitionDataToPoolIndex) - - // Deduplicate name mapping - metadata.nameMapping.foreach { nm => - val nmIdx = nameMappingToPoolIndex.getOrElseUpdate( - nm, { - val idx = nameMappingToPoolIndex.size - icebergScanBuilder.addNameMappingPool(nm) - idx - }) - taskBuilder.setNameMappingIdx(nmIdx) + val schemaIdx = schemaToPoolIndex.getOrElseUpdate( + schema, { + val idx = schemaToPoolIndex.size + val schemaJson = toJsonMethod.invoke(null, schema).asInstanceOf[String] + commonBuilder.addSchemaPool(schemaJson) + idx + }) + taskBuilder.setSchemaIdx(schemaIdx) + + val nameToFieldId = IcebergReflection.buildFieldIdMapping(schema) + + val projectFieldIds = output.flatMap { attr => + nameToFieldId + .get(attr.name) + .orElse(metadata.globalFieldIdMapping.get(attr.name)) + .orElse { + logWarning(s"Column '${attr.name}' not found in task or scan schema, " + + "skipping projection") + None } + } - partitionBuilder.addFileScanTasks(taskBuilder.build()) + val projectFieldIdsIdx = projectFieldIdsToPoolIndex.getOrElseUpdate( + projectFieldIds, { + val idx = projectFieldIdsToPoolIndex.size + val listBuilder = OperatorOuterClass.ProjectFieldIdList.newBuilder() + projectFieldIds.foreach(id => listBuilder.addFieldIds(id)) + commonBuilder.addProjectFieldIdsPool(listBuilder.build()) + idx + }) + taskBuilder.setProjectFieldIdsIdx(projectFieldIdsIdx) + + val deleteFilesList = + extractDeleteFilesList(task, contentFileClass, fileScanTaskClass) + if (deleteFilesList.nonEmpty) { + val deleteFilesIdx = deleteFilesToPoolIndex.getOrElseUpdate( + deleteFilesList, { + val idx = deleteFilesToPoolIndex.size + val listBuilder = OperatorOuterClass.DeleteFileList.newBuilder() + deleteFilesList.foreach(df => listBuilder.addDeleteFiles(df)) + commonBuilder.addDeleteFilesPool(listBuilder.build()) + idx + }) + taskBuilder.setDeleteFilesIdx(deleteFilesIdx) + } + + val residualExprOpt = + try { + val residualExpr = residualMethod.invoke(task) + val catalystExpr = convertIcebergExpression(residualExpr, output) + catalystExpr.flatMap { expr => + exprToProto(expr, output, binding = false) + } + } catch { + case e: Exception => + logWarning( + "Failed to extract residual expression from FileScanTask: " + + s"${e.getMessage}") + None } + + residualExprOpt.foreach { residualExpr => + val residualIdx = residualToPoolIndex.getOrElseUpdate( + Some(residualExpr), { + val idx = residualToPoolIndex.size + commonBuilder.addResidualPool(residualExpr) + idx + }) + taskBuilder.setResidualIdx(residualIdx) + } + + serializePartitionData( + task, + contentScanTaskClass, + fileScanTaskClass, + taskBuilder, + commonBuilder, + partitionTypeToPoolIndex, + partitionSpecToPoolIndex, + partitionDataToPoolIndex) + + metadata.nameMapping.foreach { nm => + val nmIdx = nameMappingToPoolIndex.getOrElseUpdate( + nm, { + val idx = nameMappingToPoolIndex.size + commonBuilder.addNameMappingPool(nm) + idx + }) + taskBuilder.setNameMappingIdx(nmIdx) } + + partitionBuilder.addFileScanTasks(taskBuilder.build()) } } - - val builtPartition = partitionBuilder.build() - icebergScanBuilder.addFilePartitions(builtPartition) } - case _ => - } - } catch { - case e: Exception => - // CometScanRule already validated this scan should use native execution. - // Failure here is a programming error, not a graceful fallback scenario. 
- throw new IllegalStateException( - s"Native Iceberg scan serialization failed unexpectedly: ${e.getMessage}", - e) + + perPartitionBuilders += partitionBuilder.build() + } + case _ => + throw new IllegalStateException("Expected DataSourceRDD from BatchScanExec") } // Log deduplication summary @@ -999,7 +967,6 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit val avgDedup = if (totalTasks == 0) { "0.0" } else { - // Filter out empty pools - they shouldn't count as 100% dedup val nonEmptyPools = allPoolSizes.filter(_ > 0) if (nonEmptyPools.isEmpty) { "0.0" @@ -1009,8 +976,7 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit } } - // Calculate partition data pool size in bytes (protobuf format) - val partitionDataPoolBytes = icebergScanBuilder.getPartitionDataPoolList.asScala + val partitionDataPoolBytes = commonBuilder.getPartitionDataPoolList.asScala .map(_.getSerializedSize) .sum @@ -1021,8 +987,10 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit s"$partitionDataPoolBytes bytes (protobuf)") } - builder.clearChildren() - Some(builder.setIcebergScan(icebergScanBuilder).build()) + val commonBytes = commonBuilder.build().toByteArray + val perPartitionBytes = perPartitionBuilders.map(_.toByteArray).toArray + + (commonBytes, perPartitionBytes) } override def createExec(nativeOp: Operator, op: CometBatchScanExec): CometNativeExec = { @@ -1035,10 +1003,11 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit "Metadata should have been extracted in CometScanRule.") } - // Extract metadataLocation from the native operator - val metadataLocation = nativeOp.getIcebergScan.getMetadataLocation + // Extract metadataLocation from the native operator's common data + val metadataLocation = nativeOp.getIcebergScan.getCommon.getMetadataLocation - // Create the CometIcebergNativeScanExec using the companion object's apply method + // Pass BatchScanExec reference for deferred serialization (DPP support) + // Serialization happens at execution time after doPrepare() resolves DPP subqueries CometIcebergNativeScanExec(nativeOp, op.wrapped, op.session, metadataLocation, metadata) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index 2fd7f12c24..63a67e82f2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -19,39 +19,204 @@ package org.apache.spark.sql.comet -import org.apache.spark.{Partition, SparkContext, TaskContext} -import org.apache.spark.rdd.{RDD, RDDOperationScope} +import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.util.SerializableConfiguration + +import org.apache.comet.CometExecIterator +import org.apache.comet.serde.OperatorOuterClass + +/** + * Partition that carries per-partition planning data, avoiding closure capture of all partitions. + */ +private[spark] class CometExecPartition( + override val index: Int, + val inputPartitions: Array[Partition], + val planDataByKey: Map[String, Array[Byte]]) + extends Partition /** - * A RDD that executes Spark SQL query in Comet native execution to generate ColumnarBatch. + * Unified RDD for Comet native execution. 
+ * + * Solves the closure capture problem: instead of capturing all partitions' data in the closure + * (which gets serialized to every task), each Partition object carries only its own data. + * + * Handles three cases: + * - With inputs + per-partition data: injects planning data into operator tree + * - With inputs + no per-partition data: just zips inputs (no injection overhead) + * - No inputs: uses numPartitions to create partitions + * + * NOTE: This RDD does not handle DPP (InSubqueryExec), which is resolved in + * CometIcebergNativeScanExec.serializedPartitionData before this RDD is created. It also handles + * ScalarSubquery expressions by registering them with CometScalarSubquery before execution. */ private[spark] class CometExecRDD( sc: SparkContext, - partitionNum: Int, - var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch]) - extends RDD[ColumnarBatch](sc, Nil) { + inputRDDs: Seq[RDD[ColumnarBatch]], + commonByKey: Map[String, Array[Byte]], + @transient perPartitionByKey: Map[String, Array[Array[Byte]]], + serializedPlan: Array[Byte], + defaultNumPartitions: Int, + numOutputCols: Int, + nativeMetrics: CometMetricNode, + subqueries: Seq[ScalarSubquery], + broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, + encryptedFilePaths: Seq[String] = Seq.empty) + extends RDD[ColumnarBatch](sc, inputRDDs.map(rdd => new OneToOneDependency(rdd))) { - override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = { - f(Seq.empty, partitionNum, s.index) + // Determine partition count: from inputs if available, otherwise from parameter + private val numPartitions: Int = if (inputRDDs.nonEmpty) { + inputRDDs.head.partitions.length + } else if (perPartitionByKey.nonEmpty) { + perPartitionByKey.values.head.length + } else { + defaultNumPartitions } + // Validate all per-partition arrays have the same length to prevent + // ArrayIndexOutOfBoundsException in getPartitions (e.g., from broadcast scans with + // different partition counts after DPP filtering) + require( + perPartitionByKey.values.forall(_.length == numPartitions), + s"All per-partition arrays must have length $numPartitions, but found: " + + perPartitionByKey.map { case (key, arr) => s"$key -> ${arr.length}" }.mkString(", ")) + override protected def getPartitions: Array[Partition] = { - Array.tabulate(partitionNum)(i => - new Partition { - override def index: Int = i - }) + (0 until numPartitions).map { idx => + val inputParts = inputRDDs.map(_.partitions(idx)).toArray + val planData = perPartitionByKey.map { case (key, arr) => key -> arr(idx) } + new CometExecPartition(idx, inputParts, planData) + }.toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val partition = split.asInstanceOf[CometExecPartition] + + val inputs = inputRDDs.zip(partition.inputPartitions).map { case (rdd, part) => + rdd.iterator(part, context) + } + + // Only inject if we have per-partition planning data + val actualPlan = if (commonByKey.nonEmpty) { + val basePlan = OperatorOuterClass.Operator.parseFrom(serializedPlan) + val injected = + PlanDataInjector.injectPlanData(basePlan, commonByKey, partition.planDataByKey) + PlanDataInjector.serializeOperator(injected) + } else { + serializedPlan + } + + val it = new CometExecIterator( + CometExec.newIterId, + inputs, + numOutputCols, + actualPlan, + nativeMetrics, + numPartitions, + partition.index, + broadcastedHadoopConfForEncryption, + encryptedFilePaths) + + // Register 
ScalarSubqueries so native code can look them up + subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub)) + + Option(context).foreach { ctx => + ctx.addTaskCompletionListener[Unit] { _ => + it.close() + subqueries.foreach(sub => CometScalarSubquery.removeSubquery(it.id, sub)) + } + } + + it + } + + // Duplicates logic from Spark's ZippedPartitionsBaseRDD.getPreferredLocations + override def getPreferredLocations(split: Partition): Seq[String] = { + if (inputRDDs.isEmpty) return Nil + + val idx = split.index + val prefs = inputRDDs.map(rdd => rdd.preferredLocations(rdd.partitions(idx))) + // Prefer nodes where all inputs are local; fall back to any input's preferred location + val intersection = prefs.reduce((a, b) => a.intersect(b)) + if (intersection.nonEmpty) intersection else prefs.flatten.distinct } } object CometExecRDD { - def apply(sc: SparkContext, partitionNum: Int)( - f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch]) - : RDD[ColumnarBatch] = - withScope(sc) { - new CometExecRDD(sc, partitionNum, f) - } - private[spark] def withScope[U](sc: SparkContext)(body: => U): U = - RDDOperationScope.withScope[U](sc)(body) + /** + * Creates an RDD for standalone Iceberg scan (no parent native operators). + */ + def apply( + sc: SparkContext, + commonData: Array[Byte], + perPartitionData: Array[Array[Byte]], + numOutputCols: Int, + nativeMetrics: CometMetricNode): CometExecRDD = { + + // Standalone mode needs a placeholder plan for PlanDataInjector to fill in. + // PlanDataInjector correlates common/partition data by key (metadata_location for Iceberg). + val common = OperatorOuterClass.IcebergScanCommon.parseFrom(commonData) + val metadataLocation = common.getMetadataLocation + + val placeholderCommon = OperatorOuterClass.IcebergScanCommon + .newBuilder() + .setMetadataLocation(metadataLocation) + .build() + val placeholderScan = OperatorOuterClass.IcebergScan + .newBuilder() + .setCommon(placeholderCommon) + .build() + val placeholderPlan = OperatorOuterClass.Operator + .newBuilder() + .setIcebergScan(placeholderScan) + .build() + .toByteArray + + new CometExecRDD( + sc, + inputRDDs = Seq.empty, + commonByKey = Map(metadataLocation -> commonData), + perPartitionByKey = Map(metadataLocation -> perPartitionData), + serializedPlan = placeholderPlan, + defaultNumPartitions = perPartitionData.length, + numOutputCols = numOutputCols, + nativeMetrics = nativeMetrics, + subqueries = Seq.empty) + } + + /** + * Creates an RDD for native execution with optional per-partition planning data. 
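+ *
+ * A rough call-site sketch (illustrative variable names; the encryption arguments keep
+ * their defaults):
+ * {{{
+ *   CometExecRDD(
+ *     sparkContext,
+ *     inputRDDs = childRDDs,
+ *     commonByKey = Map(metadataLocation -> commonBytes),
+ *     perPartitionByKey = Map(metadataLocation -> perPartitionBytes),
+ *     serializedPlan = planBytes,
+ *     numPartitions = childRDDs.head.partitions.length,
+ *     numOutputCols = output.length,
+ *     nativeMetrics = nativeMetrics,
+ *     subqueries = scalarSubqueries)
+ * }}}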
+ */ + // scalastyle:off + def apply( + sc: SparkContext, + inputRDDs: Seq[RDD[ColumnarBatch]], + commonByKey: Map[String, Array[Byte]], + perPartitionByKey: Map[String, Array[Array[Byte]]], + serializedPlan: Array[Byte], + numPartitions: Int, + numOutputCols: Int, + nativeMetrics: CometMetricNode, + subqueries: Seq[ScalarSubquery], + broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, + encryptedFilePaths: Seq[String] = Seq.empty): CometExecRDD = { + // scalastyle:on + + new CometExecRDD( + sc, + inputRDDs, + commonByKey, + perPartitionByKey, + serializedPlan, + numPartitions, + numOutputCols, + nativeMetrics, + subqueries, + broadcastedHadoopConfForEncryption, + encryptedFilePaths) + } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala index 223ae4fbb7..207d8555f0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometIcebergNativeScanExec.scala @@ -21,18 +21,23 @@ package org.apache.spark.sql.comet import scala.jdk.CollectionConverters._ +import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, DynamicPruningExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.execution.{InSubqueryExec, SubqueryAdaptiveBroadcastExec} import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.AccumulatorV2 import com.google.common.base.Objects import org.apache.comet.iceberg.CometIcebergNativeScanMetadata import org.apache.comet.serde.OperatorOuterClass.Operator +import org.apache.comet.serde.operator.CometIcebergNativeScan +import org.apache.comet.shims.ShimSubqueryBroadcast /** * Native Iceberg scan operator that delegates file reading to iceberg-rust. @@ -41,6 +46,10 @@ import org.apache.comet.serde.OperatorOuterClass.Operator * execution. Iceberg's catalog and planning run in Spark to produce FileScanTasks, which are * serialized to protobuf for the native side to execute using iceberg-rust's FileIO and * ArrowReader. This provides better performance than reading through Spark's abstraction layers. + * + * Supports Dynamic Partition Pruning (DPP) by deferring partition serialization to execution + * time. The doPrepare() method waits for DPP subqueries to resolve, then lazy + * serializedPartitionData serializes the DPP-filtered partitions from inputRDD. 
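+ *
+ * Sketch of the execution-time handoff (mirrors doExecuteColumnar below; forcing
+ * commonData / perPartitionData runs the deferred serialization):
+ * {{{
+ *   val nativeMetrics = CometMetricNode.fromCometPlan(this)
+ *   CometExecRDD(sparkContext, commonData, perPartitionData, output.length, nativeMetrics)
+ * }}}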
*/ case class CometIcebergNativeScanExec( override val nativeOp: Operator, @@ -48,16 +57,128 @@ case class CometIcebergNativeScanExec( @transient override val originalPlan: BatchScanExec, override val serializedPlanOpt: SerializedPlan, metadataLocation: String, - numPartitions: Int, @transient nativeIcebergScanMetadata: CometIcebergNativeScanMetadata) - extends CometLeafExec { + extends CometLeafExec + with ShimSubqueryBroadcast { override val supportsColumnar: Boolean = true override val nodeName: String = "CometIcebergNativeScan" - override lazy val outputPartitioning: Partitioning = - UnknownPartitioning(numPartitions) + /** + * Prepare DPP subquery plans. Called by Spark's prepare() before doExecuteColumnar(). Only + * kicks off async work - doesn't wait for results (that happens in serializedPartitionData). + */ + override protected def doPrepare(): Unit = { + originalPlan.runtimeFilters.foreach { + case DynamicPruningExpression(e: InSubqueryExec) => + e.plan.prepare() + case _ => + } + super.doPrepare() + } + + /** + * Lazy partition serialization - computed after doPrepare() resolves DPP. + * + * DPP (Dynamic Partition Pruning) Flow: + * + * {{{ + * Planning time: + * CometIcebergNativeScanExec created + * - serializedPartitionData not evaluated (lazy) + * - No partition serialization yet + * + * Execution time: + * 1. Spark calls prepare() on the plan tree + * - doPrepare() calls e.plan.prepare() for each DPP filter + * - Broadcast exchange starts materializing + * + * 2. Spark calls doExecuteColumnar() + * - Accesses perPartitionData + * - Forces serializedPartitionData evaluation (here) + * - Waits for DPP values (updateResult or reflection) + * - Calls serializePartitions with DPP-filtered inputRDD + * - Only matching partitions are serialized + * }}} + */ + @transient private lazy val serializedPartitionData: (Array[Byte], Array[Array[Byte]]) = { + // Ensure DPP subqueries are resolved before accessing inputRDD. + originalPlan.runtimeFilters.foreach { + case DynamicPruningExpression(e: InSubqueryExec) if e.values().isEmpty => + e.plan match { + case sab: SubqueryAdaptiveBroadcastExec => + // SubqueryAdaptiveBroadcastExec.executeCollect() throws, so we call + // child.executeCollect() directly. We use the index from SAB to find the + // right buildKey, then locate that key's column in child.output. + val rows = sab.child.executeCollect() + val indices = getSubqueryBroadcastIndices(sab) + + // SPARK-46946 changed index: Int to indices: Seq[Int] as a preparatory refactor + // for future features (Null Safe Equality DPP, multiple equality predicates). + // Currently indices always has one element. CometScanRule checks for multi-index + // DPP and falls back, so this assertion should never fail. + assert( + indices.length == 1, + s"Multi-index DPP not supported: indices=$indices. 
See SPARK-46946.") + val buildKeyIndex = indices.head + val buildKey = sab.buildKeys(buildKeyIndex) + + // Find column index in child.output by matching buildKey's exprId + val colIndex = buildKey match { + case attr: Attribute => + sab.child.output.indexWhere(_.exprId == attr.exprId) + // DPP may cast partition column to match join key type + case Cast(attr: Attribute, _, _, _) => + sab.child.output.indexWhere(_.exprId == attr.exprId) + case _ => buildKeyIndex + } + if (colIndex < 0) { + throw new IllegalStateException( + s"DPP build key '$buildKey' not found in ${sab.child.output.map(_.name)}") + } + + setInSubqueryResult(e, rows.map(_.get(colIndex, e.child.dataType))) + case _ => + e.updateResult() + } + case _ => + } + + CometIcebergNativeScan.serializePartitions(originalPlan, output, nativeIcebergScanMetadata) + } + + /** + * Sets InSubqueryExec's private result field via reflection. + * + * Reflection is required because: + * - SubqueryAdaptiveBroadcastExec.executeCollect() throws UnsupportedOperationException + * - InSubqueryExec has no public setter for result, only updateResult() which calls + * executeCollect() + * - We can't replace e.plan since it's a val + */ + private def setInSubqueryResult(e: InSubqueryExec, result: Array[_]): Unit = { + val fields = e.getClass.getDeclaredFields + // Field name is mangled by Scala compiler, e.g. "org$apache$...$InSubqueryExec$$result" + val resultField = fields + .find(f => f.getName.endsWith("$result") && !f.getName.contains("Broadcast")) + .getOrElse { + throw new IllegalStateException( + s"Cannot find 'result' field in ${e.getClass.getName}. " + + "Spark version may be incompatible with Comet's DPP implementation.") + } + resultField.setAccessible(true) + resultField.set(e, result) + } + + def commonData: Array[Byte] = serializedPartitionData._1 + def perPartitionData: Array[Array[Byte]] = serializedPartitionData._2 + + // numPartitions for execution - derived from actual DPP-filtered partitions + // Only accessed during execution, not planning + def numPartitions: Int = perPartitionData.length + + override lazy val outputPartitioning: Partitioning = UnknownPartitioning(numPartitions) override lazy val outputOrdering: Seq[SortOrder] = Nil @@ -95,17 +216,26 @@ case class CometIcebergNativeScanExec( } } - private val capturedMetricValues: Seq[MetricValue] = { - originalPlan.metrics - .filterNot { case (name, _) => - // Filter out metrics that are now runtime metrics incremented on the native side - name == "numOutputRows" || name == "numDeletes" || name == "numSplits" - } - .map { case (name, metric) => - val mappedType = mapMetricType(name, metric.metricType) - MetricValue(name, metric.value, mappedType) - } - .toSeq + @transient private lazy val capturedMetricValues: Seq[MetricValue] = { + // Guard against null originalPlan (from doCanonicalize) + if (originalPlan == null) { + Seq.empty + } else { + // Force serializedPartitionData evaluation first - this triggers serializePartitions which + // accesses inputRDD, which triggers Iceberg planning and populates metrics + val _ = serializedPartitionData + + originalPlan.metrics + .filterNot { case (name, _) => + // Filter out metrics that are now runtime metrics incremented on the native side + name == "numOutputRows" || name == "numDeletes" || name == "numSplits" + } + .map { case (name, metric) => + val mappedType = mapMetricType(name, metric.metricType) + MetricValue(name, metric.value, mappedType) + } + .toSeq + } } /** @@ -146,62 +276,78 @@ case class CometIcebergNativeScanExec( 
baseMetrics ++ icebergMetrics + ("num_splits" -> numSplitsMetric) } + /** Executes using CometExecRDD - planning data is computed lazily on first access. */ + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val nativeMetrics = CometMetricNode.fromCometPlan(this) + CometExecRDD(sparkContext, commonData, perPartitionData, output.length, nativeMetrics) + } + + /** + * Override convertBlock to preserve @transient fields. The parent implementation uses + * makeCopy() which loses transient fields. + */ + override def convertBlock(): CometIcebergNativeScanExec = { + // Serialize the native plan if not already done + val newSerializedPlan = if (serializedPlanOpt.isEmpty) { + val bytes = CometExec.serializeNativePlan(nativeOp) + SerializedPlan(Some(bytes)) + } else { + serializedPlanOpt + } + + // Create new instance preserving transient fields + CometIcebergNativeScanExec( + nativeOp, + output, + originalPlan, + newSerializedPlan, + metadataLocation, + nativeIcebergScanMetadata) + } + override protected def doCanonicalize(): CometIcebergNativeScanExec = { CometIcebergNativeScanExec( nativeOp, output.map(QueryPlan.normalizeExpressions(_, output)), - originalPlan.doCanonicalize(), + null, // Don't need originalPlan for canonicalization SerializedPlan(None), metadataLocation, - numPartitions, - nativeIcebergScanMetadata) + null + ) // Don't need metadata for canonicalization } - override def stringArgs: Iterator[Any] = - Iterator(output, s"$metadataLocation, ${originalPlan.scan.description()}", numPartitions) + override def stringArgs: Iterator[Any] = { + // Use metadata task count to avoid triggering serializedPartitionData during planning + val hasMeta = nativeIcebergScanMetadata != null && nativeIcebergScanMetadata.tasks != null + val taskCount = if (hasMeta) nativeIcebergScanMetadata.tasks.size() else 0 + val scanDesc = if (originalPlan != null) originalPlan.scan.description() else "canonicalized" + // Include runtime filters (DPP) in string representation + val runtimeFiltersStr = if (originalPlan != null && originalPlan.runtimeFilters.nonEmpty) { + s", runtimeFilters=${originalPlan.runtimeFilters.mkString("[", ", ", "]")}" + } else { + "" + } + Iterator(output, s"$metadataLocation, $scanDesc$runtimeFiltersStr", taskCount) + } override def equals(obj: Any): Boolean = { obj match { case other: CometIcebergNativeScanExec => this.metadataLocation == other.metadataLocation && this.output == other.output && - this.serializedPlanOpt == other.serializedPlanOpt && - this.numPartitions == other.numPartitions + this.serializedPlanOpt == other.serializedPlanOpt case _ => false } } override def hashCode(): Int = - Objects.hashCode( - metadataLocation, - output.asJava, - serializedPlanOpt, - numPartitions: java.lang.Integer) + Objects.hashCode(metadataLocation, output.asJava, serializedPlanOpt) } object CometIcebergNativeScanExec { - /** - * Creates a CometIcebergNativeScanExec from a Spark BatchScanExec. 
- * - * Determines the number of partitions from Iceberg's output partitioning: - * - KeyGroupedPartitioning: Use Iceberg's partition count - * - Other cases: Use the number of InputPartitions from Iceberg's planning - * - * @param nativeOp - * The serialized native operator - * @param scanExec - * The original Spark BatchScanExec - * @param session - * The SparkSession - * @param metadataLocation - * Path to table metadata file - * @param nativeIcebergScanMetadata - * Pre-extracted Iceberg metadata from planning phase - * @return - * A new CometIcebergNativeScanExec - */ + /** Creates a CometIcebergNativeScanExec with deferred partition serialization. */ def apply( nativeOp: Operator, scanExec: BatchScanExec, @@ -209,21 +355,12 @@ object CometIcebergNativeScanExec { metadataLocation: String, nativeIcebergScanMetadata: CometIcebergNativeScanMetadata): CometIcebergNativeScanExec = { - // Determine number of partitions from Iceberg's output partitioning - val numParts = scanExec.outputPartitioning match { - case p: KeyGroupedPartitioning => - p.numPartitions - case _ => - scanExec.inputRDD.getNumPartitions - } - val exec = CometIcebergNativeScanExec( nativeOp, scanExec.output, scanExec, SerializedPlan(None), metadataLocation, - numParts, nativeIcebergScanMetadata) scanExec.logicalLink.foreach(exec.setLogicalLink) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala deleted file mode 100644 index fdf8bf393d..0000000000 --- a/spark/src/main/scala/org/apache/spark/sql/comet/ZippedPartitionsRDD.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.spark.sql.comet - -import org.apache.spark.{Partition, SparkContext, TaskContext} -import org.apache.spark.rdd.{RDD, RDDOperationScope, ZippedPartitionsBaseRDD, ZippedPartitionsPartition} -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * Similar to Spark `ZippedPartitionsRDD[1-4]` classes, this class is used to zip partitions of - * the multiple RDDs into a single RDD. Spark `ZippedPartitionsRDD[1-4]` classes only support at - * most 4 RDDs. This class is used to support more than 4 RDDs. This ZipPartitionsRDD is used to - * zip the input sources of the Comet physical plan. So it only zips partitions of ColumnarBatch. 
- */ -private[spark] class ZippedPartitionsRDD( - sc: SparkContext, - var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch], - var zipRdds: Seq[RDD[ColumnarBatch]], - preservesPartitioning: Boolean = false) - extends ZippedPartitionsBaseRDD[ColumnarBatch](sc, zipRdds, preservesPartitioning) { - - // We need to get the number of partitions in `compute` but `getNumPartitions` is not available - // on the executors. So we need to capture it here. - private val numParts: Int = this.getNumPartitions - - override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = { - val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions - val iterators = - zipRdds.zipWithIndex.map(pair => pair._1.iterator(partitions(pair._2), context)) - f(iterators, numParts, s.index) - } - - override def clearDependencies(): Unit = { - super.clearDependencies() - zipRdds = null - f = null - } -} - -object ZippedPartitionsRDD { - def apply(sc: SparkContext, rdds: Seq[RDD[ColumnarBatch]])( - f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch]) - : RDD[ColumnarBatch] = - withScope(sc) { - new ZippedPartitionsRDD(sc, f, rdds) - } - - private[spark] def withScope[U](sc: SparkContext)(body: => U): U = - RDDOperationScope.withScope[U](sc)(body) -} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 6f33467efe..9fe5d730ca 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -25,7 +25,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ -import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -59,6 +58,126 @@ import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregat import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, supportedSortType} import org.apache.comet.serde.operator.CometSink +/** + * Trait for injecting per-partition planning data into operator nodes. + * + * Implementations handle specific operator types (e.g., Iceberg scans, Delta scans). + */ +private[comet] trait PlanDataInjector { + + /** Check if this injector can handle the given operator. */ + def canInject(op: Operator): Boolean + + /** Extract the key used to look up planning data for this operator. */ + def getKey(op: Operator): Option[String] + + /** Inject common + partition data into the operator node. */ + def inject(op: Operator, commonBytes: Array[Byte], partitionBytes: Array[Byte]): Operator +} + +/** + * Registry and utilities for injecting per-partition planning data into operator trees. + */ +private[comet] object PlanDataInjector { + + // Registry of injectors for different operator types + private val injectors: Seq[PlanDataInjector] = Seq( + IcebergPlanDataInjector + // Future: DeltaPlanDataInjector, HudiPlanDataInjector, etc. + ) + + /** + * Injects planning data into an Operator tree by finding nodes that need injection and applying + * the appropriate injector. + * + * Supports joins over multiple tables by matching each operator with its corresponding data + * based on a key (e.g., metadata_location for Iceberg). 
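+   * A rough sketch of the intended executor-side usage; the names `planWithoutTasks`,
+   * `commonByKey` and `perPartitionByKey` are illustrative placeholders, not the exact
+   * call-site variables:
+   * {{{
+   * // For partition `i` of the stage: pick this partition's task bytes for each key,
+   * // rebuild a fully-populated plan, and re-serialize it for the native side.
+   * val partitionSlice = perPartitionByKey.map { case (key, parts) => key -> parts(i) }
+   * val injected = PlanDataInjector.injectPlanData(planWithoutTasks, commonByKey, partitionSlice)
+   * val planBytes = PlanDataInjector.serializeOperator(injected)
+   * }}}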
+ */ + def injectPlanData( + op: Operator, + commonByKey: Map[String, Array[Byte]], + partitionByKey: Map[String, Array[Byte]]): Operator = { + val builder = op.toBuilder + + // Try each injector to see if it can handle this operator + for (injector <- injectors if injector.canInject(op)) { + injector.getKey(op) match { + case Some(key) => + (commonByKey.get(key), partitionByKey.get(key)) match { + case (Some(commonBytes), Some(partitionBytes)) => + val injectedOp = injector.inject(op, commonBytes, partitionBytes) + // Copy the injected operator's fields to our builder + builder.clear() + builder.mergeFrom(injectedOp) + case _ => + throw new CometRuntimeException(s"Missing planning data for key: $key") + } + case None => // No key, skip injection + } + } + + // Recursively process children + builder.clearChildren() + op.getChildrenList.asScala.foreach { child => + builder.addChildren(injectPlanData(child, commonByKey, partitionByKey)) + } + + builder.build() + } + + def serializeOperator(op: Operator): Array[Byte] = { + val size = op.getSerializedSize + val bytes = new Array[Byte](size) + val codedOutput = CodedOutputStream.newInstance(bytes) + op.writeTo(codedOutput) + codedOutput.checkNoSpaceLeft() + bytes + } +} + +/** + * Injector for Iceberg scan operators. + */ +private[comet] object IcebergPlanDataInjector extends PlanDataInjector { + import java.nio.ByteBuffer + import java.util.concurrent.ConcurrentHashMap + + // Cache parsed IcebergScanCommon by content to avoid repeated deserialization + // ByteBuffer wrapper provides content-based equality and hashCode + // TODO: This is a static singleton on the executor, should we cap the size (proper LRU cache?) + private val commonCache = + new ConcurrentHashMap[ByteBuffer, OperatorOuterClass.IcebergScanCommon]() + + override def canInject(op: Operator): Boolean = + op.hasIcebergScan && + op.getIcebergScan.getFileScanTasksCount == 0 && + op.getIcebergScan.hasCommon + + override def getKey(op: Operator): Option[String] = + Some(op.getIcebergScan.getCommon.getMetadataLocation) + + override def inject( + op: Operator, + commonBytes: Array[Byte], + partitionBytes: Array[Byte]): Operator = { + val scan = op.getIcebergScan + + // Cache the parsed common data to avoid deserializing on every partition + val cacheKey = ByteBuffer.wrap(commonBytes) + val common = commonCache.computeIfAbsent( + cacheKey, + _ => OperatorOuterClass.IcebergScanCommon.parseFrom(commonBytes)) + + val tasksOnly = OperatorOuterClass.IcebergScan.parseFrom(partitionBytes) + + val scanBuilder = scan.toBuilder + scanBuilder.setCommon(common) + scanBuilder.addAllFileScanTasks(tasksOnly.getFileScanTasksList) + + op.toBuilder.setIcebergScan(scanBuilder).build() + } +} + /** * A Comet physical operator */ @@ -105,6 +224,15 @@ abstract class CometExec extends CometPlan { } } } + + /** Collects all ScalarSubquery expressions from a plan tree. 
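+   * Presumably the collected subqueries are handed to CometExecRDD so their results can be
+   * registered with (and cleaned up after) the native iterator per task, mirroring the
+   * setSubqueries/cleanSubqueries handling that previously lived in createCometExecIter.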
*/ + protected def collectSubqueries(sparkPlan: SparkPlan): Seq[ScalarSubquery] = { + val childSubqueries = sparkPlan.children.flatMap(collectSubqueries) + val planSubqueries = sparkPlan.expressions.flatMap { + _.collect { case sub: ScalarSubquery => sub } + } + childSubqueries ++ planSubqueries + } } object CometExec { @@ -290,32 +418,8 @@ abstract class CometNativeExec extends CometExec { case None => (None, Seq.empty) } - def createCometExecIter( - inputs: Seq[Iterator[ColumnarBatch]], - numParts: Int, - partitionIndex: Int): CometExecIterator = { - val it = new CometExecIterator( - CometExec.newIterId, - inputs, - output.length, - serializedPlanCopy, - nativeMetrics, - numParts, - partitionIndex, - broadcastedHadoopConfForEncryption, - encryptedFilePaths) - - setSubqueries(it.id, this) - - Option(TaskContext.get()).foreach { context => - context.addTaskCompletionListener[Unit] { _ => - it.close() - cleanSubqueries(it.id, this) - } - } - - it - } + // Find planning data within this stage (stops at shuffle boundaries). + val (commonByKey, perPartitionByKey) = findAllPlanData(this) // Collect the input ColumnarBatches from the child operators and create a CometExecIterator // to execute the native plan. @@ -395,12 +499,20 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") } - if (inputs.nonEmpty) { - ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter) - } else { - val partitionNum = firstNonBroadcastPlanNumPartitions - CometExecRDD(sparkContext, partitionNum)(createCometExecIter) - } + // Unified RDD creation - CometExecRDD handles all cases + val subqueries = collectSubqueries(this) + CometExecRDD( + sparkContext, + inputs.toSeq, + commonByKey, + perPartitionByKey, + serializedPlanCopy, + firstNonBroadcastPlanNumPartitions, + output.length, + nativeMetrics, + subqueries, + broadcastedHadoopConfForEncryption, + encryptedFilePaths) } } @@ -440,6 +552,49 @@ abstract class CometNativeExec extends CometExec { } } + /** + * Find all plan nodes with per-partition planning data in the plan tree. Returns two maps keyed + * by a unique identifier: one for common data (shared across partitions) and one for + * per-partition data. + * + * Currently supports Iceberg scans (keyed by metadata_location). Additional scan types can be + * added by extending this method. + * + * Stops at stage boundaries (shuffle exchanges, etc.) because partition indices are only valid + * within the same stage. + * + * @return + * (commonByKey, perPartitionByKey) - common data is shared, per-partition varies + */ + private def findAllPlanData( + plan: SparkPlan): (Map[String, Array[Byte]], Map[String, Array[Array[Byte]]]) = { + plan match { + // Found an Iceberg scan with planning data + case iceberg: CometIcebergNativeScanExec + if iceberg.commonData.nonEmpty && iceberg.perPartitionData.nonEmpty => + ( + Map(iceberg.metadataLocation -> iceberg.commonData), + Map(iceberg.metadataLocation -> iceberg.perPartitionData)) + + // Broadcast stages are boundaries - don't collect per-partition data from inside them. + // After DPP filtering, broadcast scans may have different partition counts than the + // probe side, causing ArrayIndexOutOfBoundsException in CometExecRDD.getPartitions. 
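+      // For example (hypothetical numbers): a broadcast-side scan pruned by DPP down to one
+      // serialized partition, combined with a probe side that still has eight partitions,
+      // would index past the end of the broadcast side's per-partition array.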
+ case _: BroadcastQueryStageExec | _: CometBroadcastExchangeExec => + (Map.empty, Map.empty) + + // Stage boundaries - stop searching (partition indices won't align after these) + case _: ShuffleQueryStageExec | _: AQEShuffleReadExec | _: CometShuffleExchangeExec | + _: CometUnionExec | _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | + _: ReusedExchangeExec | _: CometSparkToColumnarExec => + (Map.empty, Map.empty) + + // Continue searching through other operators, combining results from all children + case _ => + val results = plan.children.map(findAllPlanData) + (results.flatMap(_._1).toMap, results.flatMap(_._2).toMap) + } + } + /** * Converts this native Comet operator and its children into a native block which can be * executed as a whole (i.e., in a single JNI call) from the native side. diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala new file mode 100644 index 0000000000..1ff0935041 --- /dev/null +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.execution.SubqueryAdaptiveBroadcastExec + +trait ShimSubqueryBroadcast { + + /** + * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x has `index: Int`, + * Spark 4.x has `indices: Seq[Int]`. + */ + def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = { + Seq(sab.index) + } +} diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala new file mode 100644 index 0000000000..1ff0935041 --- /dev/null +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.shims + +import org.apache.spark.sql.execution.SubqueryAdaptiveBroadcastExec + +trait ShimSubqueryBroadcast { + + /** + * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x has `index: Int`, + * Spark 4.x has `indices: Seq[Int]`. + */ + def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = { + Seq(sab.index) + } +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSubqueryBroadcast.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSubqueryBroadcast.scala new file mode 100644 index 0000000000..417dfd46b7 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimSubqueryBroadcast.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.execution.SubqueryAdaptiveBroadcastExec + +trait ShimSubqueryBroadcast { + + /** + * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x has `index: Int`, + * Spark 4.x has `indices: Seq[Int]`. 
+ */ + def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = { + sab.indices + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala index f3c8a8b2a6..30521dbad7 100644 --- a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala @@ -2295,6 +2295,84 @@ class CometIcebergNativeSuite extends CometTestBase with RESTCatalogHelper { } } + test("runtime filtering - multiple DPP filters on two partition columns") { + assume(icebergAvailable, "Iceberg not available") + withTempIcebergDir { warehouseDir => + val dimDir = new File(warehouseDir, "dim_parquet") + withSQLConf( + "spark.sql.catalog.runtime_cat" -> "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.runtime_cat.type" -> "hadoop", + "spark.sql.catalog.runtime_cat.warehouse" -> warehouseDir.getAbsolutePath, + "spark.sql.autoBroadcastJoinThreshold" -> "1KB", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true") { + + // Create table partitioned by TWO columns: (data, bucket(8, id)) + // This mimics Iceberg's testMultipleRuntimeFilters + spark.sql(""" + CREATE TABLE runtime_cat.db.multi_dpp_fact ( + id BIGINT, + data STRING, + date DATE, + ts TIMESTAMP + ) USING iceberg + PARTITIONED BY (data, bucket(8, id)) + """) + + // Insert data - 99 rows with varying data and id values + val df = spark + .range(1, 100) + .selectExpr( + "id", + "CAST(DATE_ADD(DATE '1970-01-01', CAST(id % 4 AS INT)) AS STRING) as data", + "DATE_ADD(DATE '1970-01-01', CAST(id % 4 AS INT)) as date", + "CAST(DATE_ADD(DATE '1970-01-01', CAST(id % 4 AS INT)) AS TIMESTAMP) as ts") + df.coalesce(1) + .write + .format("iceberg") + .option("fanout-enabled", "true") + .mode("append") + .saveAsTable("runtime_cat.db.multi_dpp_fact") + + // Create dimension table with specific id=1, data='1970-01-02' + spark + .createDataFrame(Seq((1L, java.sql.Date.valueOf("1970-01-02"), "1970-01-02"))) + .toDF("id", "date", "data") + .write + .parquet(dimDir.getAbsolutePath) + spark.read.parquet(dimDir.getAbsolutePath).createOrReplaceTempView("dim") + + // Join on BOTH partition columns - this creates TWO DPP filters + val query = + """SELECT /*+ BROADCAST(d) */ f.* + |FROM runtime_cat.db.multi_dpp_fact f + |JOIN dim d ON f.id = d.id AND f.data = d.data + |WHERE d.date = DATE '1970-01-02'""".stripMargin + + // Verify plan has 2 dynamic pruning expressions + val df2 = spark.sql(query) + val planStr = df2.queryExecution.executedPlan.toString + // Count "dynamicpruningexpression(" to avoid matching "dynamicpruning#N" references + val dppCount = "dynamicpruningexpression\\(".r.findAllIn(planStr).length + assert(dppCount == 2, s"Expected 2 DPP expressions but found $dppCount in:\n$planStr") + + // Verify native Iceberg scan is used and DPP actually pruned partitions + val (_, cometPlan) = checkSparkAnswer(query) + val icebergScans = collectIcebergNativeScans(cometPlan) + assert( + icebergScans.nonEmpty, + s"Expected CometIcebergNativeScanExec but found none. 
Plan:\n$cometPlan") + // With 4 data values x 8 buckets = up to 32 partitions total + // DPP on (data='1970-01-02', bucket(id=1)) should prune to 1 + val numPartitions = icebergScans.head.numPartitions + assert(numPartitions == 1, s"Expected DPP to prune to 1 partition but got $numPartitions") + + spark.sql("DROP TABLE runtime_cat.db.multi_dpp_fact") + } + } + } + test("runtime filtering - join with dynamic partition pruning") { assume(icebergAvailable, "Iceberg not available") withTempIcebergDir { warehouseDir => @@ -2303,11 +2381,14 @@ class CometIcebergNativeSuite extends CometTestBase with RESTCatalogHelper { "spark.sql.catalog.runtime_cat" -> "org.apache.iceberg.spark.SparkCatalog", "spark.sql.catalog.runtime_cat.type" -> "hadoop", "spark.sql.catalog.runtime_cat.warehouse" -> warehouseDir.getAbsolutePath, + // Prevent fact table from being broadcast (force dimension to be broadcast) + "spark.sql.autoBroadcastJoinThreshold" -> "1KB", CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true", CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true") { - // Create partitioned Iceberg table (fact table) + // Create partitioned Iceberg table (fact table) with 3 partitions + // Add enough data to prevent broadcast spark.sql(""" CREATE TABLE runtime_cat.db.fact_table ( id BIGINT, @@ -2323,7 +2404,11 @@ class CometIcebergNativeSuite extends CometTestBase with RESTCatalogHelper { (1, 'a', DATE '1970-01-01'), (2, 'b', DATE '1970-01-02'), (3, 'c', DATE '1970-01-02'), - (4, 'd', DATE '1970-01-03') + (4, 'd', DATE '1970-01-03'), + (5, 'e', DATE '1970-01-01'), + (6, 'f', DATE '1970-01-02'), + (7, 'g', DATE '1970-01-03'), + (8, 'h', DATE '1970-01-01') """) // Create dimension table (Parquet) in temp directory @@ -2335,8 +2420,9 @@ class CometIcebergNativeSuite extends CometTestBase with RESTCatalogHelper { spark.read.parquet(dimDir.getAbsolutePath).createOrReplaceTempView("dim") // This join should trigger dynamic partition pruning + // Use BROADCAST hint to force dimension table to be broadcast val query = - """SELECT f.* FROM runtime_cat.db.fact_table f + """SELECT /*+ BROADCAST(d) */ f.* FROM runtime_cat.db.fact_table f |JOIN dim d ON f.date = d.date AND d.id = 1 |ORDER BY f.id""".stripMargin @@ -2348,13 +2434,17 @@ class CometIcebergNativeSuite extends CometTestBase with RESTCatalogHelper { planStr.contains("dynamicpruning"), s"Expected dynamic pruning in plan but got:\n$planStr") - // Check results match Spark - // Note: AQE re-plans after subquery executes, converting dynamicpruningexpression(...) - // to dynamicpruningexpression(true), which allows native Iceberg scan to proceed. - // This is correct behavior - no actual subquery to wait for after AQE re-planning. - // However, the rest of the still contains non-native operators because CometExecRule - // doesn't run again. - checkSparkAnswer(df) + // Should now use native Iceberg scan with DPP + checkIcebergNativeScan(query) + + // Verify DPP actually pruned partitions (should only scan 1 of 3 partitions) + val (_, cometPlan) = checkSparkAnswer(query) + val icebergScans = collectIcebergNativeScans(cometPlan) + assert( + icebergScans.nonEmpty, + s"Expected CometIcebergNativeScanExec but found none. Plan:\n$cometPlan") + val numPartitions = icebergScans.head.numPartitions + assert(numPartitions == 1, s"Expected DPP to prune to 1 partition but got $numPartitions") spark.sql("DROP TABLE runtime_cat.db.fact_table") }