diff --git a/src/query/catalog/src/sbbf.rs b/src/query/catalog/src/sbbf.rs
index 368a5f4ac2893..df01154f39722 100644
--- a/src/query/catalog/src/sbbf.rs
+++ b/src/query/catalog/src/sbbf.rs
@@ -239,7 +239,7 @@ impl Sbbf {
     /// Create a new [Sbbf] with given number of bytes, the exact number of bytes will be adjusted
     /// to the next power of two bounded by [BITSET_MIN_LENGTH] and [BITSET_MAX_LENGTH].
-    pub(crate) fn new_with_num_of_bytes(num_bytes: usize) -> Self {
+    pub fn new_with_num_of_bytes(num_bytes: usize) -> Self {
         let num_bytes = optimal_num_of_bytes(num_bytes);
         assert_eq!(num_bytes % size_of::<Block>(), 0);
         let num_blocks = num_bytes / size_of::<Block>();
@@ -307,6 +307,52 @@ impl Sbbf {
     pub fn estimated_memory_size(&self) -> usize {
         self.0.capacity() * std::mem::size_of::<Block>()
     }
+
+    /// Serialize the bloom filter into a little-endian byte array.
+    /// The layout is a contiguous sequence of blocks, each block consisting
+    /// of 8 u32 values in little-endian order.
+    pub fn to_bytes(&self) -> Vec<u8> {
+        let mut bytes = Vec::with_capacity(self.0.len() * size_of::<Block>());
+        for block in &self.0 {
+            for value in block.0 {
+                bytes.extend_from_slice(&value.to_le_bytes());
+            }
+        }
+        bytes
+    }
+
+    /// Deserialize a bloom filter from bytes produced by `to_bytes`.
+    pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
+        if bytes.len() % size_of::<Block>() != 0 {
+            return Err(format!(
+                "Invalid bloom filter bytes length {}, expected a multiple of {}",
+                bytes.len(),
+                size_of::<Block>()
+            ));
+        }
+
+        let num_blocks = bytes.len() / size_of::<Block>();
+        if num_blocks == 0 {
+            return Ok(Sbbf(Vec::new()));
+        }
+
+        let mut blocks = Vec::with_capacity(num_blocks);
+        let mut offset = 0;
+        for _ in 0..num_blocks {
+            let mut arr = [0u32; 8];
+            for value in &mut arr {
+                let end = offset + size_of::<u32>();
+                let chunk = bytes
+                    .get(offset..end)
+                    .ok_or_else(|| "Invalid bloom filter bytes".to_string())?;
+                *value = u32::from_le_bytes(chunk.try_into().unwrap());
+                offset = end;
+            }
+            blocks.push(Block(arr));
+        }
+
+        Ok(Sbbf(blocks))
+    }
 }
 
 impl SbbfAtomic {
@@ -320,7 +366,7 @@ impl SbbfAtomic {
         Ok(Self::new_with_num_of_bytes(num_bits / 8))
     }
 
-    pub(crate) fn new_with_num_of_bytes(num_bytes: usize) -> Self {
+    pub fn new_with_num_of_bytes(num_bytes: usize) -> Self {
         let num_bytes = optimal_num_of_bytes(num_bytes);
         assert_eq!(size_of::<AtomicBlock>(), size_of::<Block>());
         assert_eq!(num_bytes % size_of::<Block>(), 0);
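The new `to_bytes`/`from_bytes` pair is what lets filters cross node boundaries, so a round-trip check is worth having. A minimal test sketch, not part of the patch, assuming only APIs visible in this diff (`new_with_num_of_bytes`, `insert_hash_batch`, and byte-level equality of re-serialized filters):

```rust
#[test]
fn sbbf_bytes_round_trip() {
    // Build a small filter and insert a few pre-computed 64-bit hashes.
    let mut filter = Sbbf::new_with_num_of_bytes(1024);
    filter.insert_hash_batch(&[0x1234_5678_9abc_def0, 42, u64::MAX]);

    // The layout is 8 little-endian u32 values (32 bytes) per block.
    let bytes = filter.to_bytes();
    assert_eq!(bytes.len() % 32, 0);

    // Deserializing and re-serializing must reproduce the same bits.
    let restored = Sbbf::from_bytes(&bytes).unwrap();
    assert_eq!(restored.to_bytes(), bytes);

    // A truncated buffer must be rejected with an error, not a panic.
    assert!(Sbbf::from_bytes(&bytes[..bytes.len() - 1]).is_err());
}
```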
diff --git a/src/query/service/src/physical_plans/runtime_filter/builder.rs b/src/query/service/src/physical_plans/runtime_filter/builder.rs
index 60cf42e3685b5..1f0c87ee24b4a 100644
--- a/src/query/service/src/physical_plans/runtime_filter/builder.rs
+++ b/src/query/service/src/physical_plans/runtime_filter/builder.rs
@@ -18,8 +18,10 @@ use std::sync::Arc;
 use databend_common_catalog::table_context::TableContext;
 use databend_common_exception::Result;
 use databend_common_expression::types::DataType;
+use databend_common_expression::Expr;
 use databend_common_expression::RemoteExpr;
 use databend_common_functions::BUILTIN_FUNCTIONS;
+use databend_common_sql::optimizer::ir::ColumnStatSet;
 use databend_common_sql::optimizer::ir::SExpr;
 use databend_common_sql::plans::Exchange;
 use databend_common_sql::plans::Join;
@@ -113,6 +115,11 @@ pub async fn build_runtime_filter(
 
     let mut filters = Vec::new();
 
+    // Derive statistics for the build side to estimate the NDV of join keys.
+    let build_rel_expr = databend_common_sql::optimizer::ir::RelExpr::with_s_expr(build_side);
+    let build_stat_info = build_rel_expr.derive_cardinality()?;
+    let build_column_stats = &build_stat_info.statistics.column_stats;
+
     let probe_side = s_expr.probe_side_child();
 
     // Process each probe key that has runtime filter information
@@ -144,10 +151,17 @@ pub async fn build_runtime_filter(
         let build_table_rows =
             get_build_table_rows(ctx.clone(), metadata, build_table_index).await?;
 
-        let data_type = build_key
-            .as_expr(&BUILTIN_FUNCTIONS)
-            .data_type()
-            .remove_nullable();
+        let build_key_expr = build_key.as_expr(&BUILTIN_FUNCTIONS);
+
+        // Estimate the NDV of the build-side join key using optimizer statistics.
+        // This handles all RemoteExpr variants by inspecting the column references
+        // inside the expression: a constant expression has an NDV of 1; an
+        // expression with exactly one column reference reuses that column's NDV;
+        // anything else falls back to the overall build-side cardinality.
+        let build_key_ndv = estimate_build_key_ndv(&build_key_expr, build_column_stats)
+            .unwrap_or_else(|| build_stat_info.cardinality.ceil() as u64);
+
+        let data_type = build_key_expr.data_type().remove_nullable();
 
         let id = metadata.write().next_runtime_filter_id();
         let enable_bloom_runtime_filter = is_type_supported_for_bloom_filter(&data_type);
@@ -159,6 +173,7 @@ pub async fn build_runtime_filter(
             id,
             build_key: build_key.clone(),
             probe_targets,
+            build_key_ndv,
             build_table_rows,
             enable_bloom_runtime_filter,
             enable_inlist_runtime_filter: true,
@@ -170,6 +185,23 @@ pub async fn build_runtime_filter(
     Ok(PhysicalRuntimeFilters { filters })
 }
 
+fn estimate_build_key_ndv(
+    build_key: &Expr,
+    build_column_stats: &ColumnStatSet,
+) -> Option<u64> {
+    let mut column_refs = build_key.column_refs();
+    if column_refs.is_empty() {
+        return Some(1);
+    }
+
+    if column_refs.len() == 1 {
+        let (id, _) = column_refs.drain().next().unwrap();
+        build_column_stats.get(&id).map(|s| s.ndv.ceil() as u64)
+    } else {
+        None
+    }
+}
+
 async fn get_build_table_rows(
     ctx: Arc<dyn TableContext>,
     metadata: &MetadataRef,
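The fallback chain in `estimate_build_key_ndv` plus the `unwrap_or_else` at the call site implements a three-case rule. A self-contained illustration, with a plain `HashMap` standing in for the optimizer's `ColumnStatSet`; the helper `ndv_of` and the sample numbers are hypothetical:

```rust
use std::collections::HashMap;

// Mirrors estimate_build_key_ndv plus the caller's cardinality fallback.
fn ndv_of(column_refs: &[usize], stats: &HashMap<usize, f64>, cardinality: f64) -> u64 {
    match column_refs {
        // Constant expression, e.g. `ON t.k = 1`: a single distinct value.
        [] => 1,
        // Exactly one column, e.g. `ON t.k = s.k`: reuse that column's NDV.
        [id] => stats
            .get(id)
            .map(|ndv| ndv.ceil() as u64)
            .unwrap_or(cardinality.ceil() as u64),
        // Composite expression, e.g. `ON t.a + t.b = s.k`: no single column's
        // NDV applies, so fall back to the build-side cardinality estimate.
        _ => cardinality.ceil() as u64,
    }
}

fn main() {
    let stats = HashMap::from([(7usize, 120.0)]);
    assert_eq!(ndv_of(&[], &stats, 1000.0), 1);
    assert_eq!(ndv_of(&[7], &stats, 1000.0), 120);
    assert_eq!(ndv_of(&[7, 8], &stats, 1000.0), 1000);
}
```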
diff --git a/src/query/service/src/physical_plans/runtime_filter/types.rs b/src/query/service/src/physical_plans/runtime_filter/types.rs
index 11a7a9992f5c6..cf1f62e93ffe6 100644
--- a/src/query/service/src/physical_plans/runtime_filter/types.rs
+++ b/src/query/service/src/physical_plans/runtime_filter/types.rs
@@ -42,6 +42,9 @@ pub struct PhysicalRuntimeFilter {
     /// All probe targets in this list are in the same equivalence class
     pub probe_targets: Vec<(RemoteExpr, usize)>,
 
+    /// Estimated NDV of the build-side join key, derived from optimizer statistics.
+    pub build_key_ndv: u64,
+
     pub build_table_rows: Option<u64>,
 
     /// Enable bloom filter for this runtime filter
diff --git a/src/query/service/src/pipelines/processors/transforms/hash_join/desc.rs b/src/query/service/src/pipelines/processors/transforms/hash_join/desc.rs
index ae2a25d06733b..bef2e6202e011 100644
--- a/src/query/service/src/pipelines/processors/transforms/hash_join/desc.rs
+++ b/src/query/service/src/pipelines/processors/transforms/hash_join/desc.rs
@@ -70,6 +70,7 @@ pub struct RuntimeFilterDesc {
     pub id: usize,
     pub build_key: Expr,
     pub probe_targets: Vec<(Expr, usize)>,
+    pub build_key_ndv: u64,
     pub build_table_rows: Option<u64>,
     pub enable_bloom_runtime_filter: bool,
     pub enable_inlist_runtime_filter: bool,
@@ -98,6 +99,7 @@ impl From<&PhysicalRuntimeFilter> for RuntimeFilterDesc {
                 .iter()
                 .map(|(probe_key, scan_id)| (probe_key.as_expr(&BUILTIN_FUNCTIONS), *scan_id))
                 .collect(),
+            build_key_ndv: runtime_filter.build_key_ndv,
             build_table_rows: runtime_filter.build_table_rows,
             enable_bloom_runtime_filter: runtime_filter.enable_bloom_runtime_filter,
             enable_inlist_runtime_filter: runtime_filter.enable_inlist_runtime_filter,
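`build_key_ndv` is threaded through to `RuntimeFilterDesc` here, but this diff does not show the consumer: the body of `should_enable_runtime_filter` in the next file is elided. A hypothetical sketch of the kind of selectivity gate such a field enables; the function name, rule, and threshold semantics are assumptions, not the patch's code:

```rust
// Hypothetical: a runtime filter only pays off when build keys repeat a lot,
// i.e. when rows-per-distinct-key exceeds some configured threshold.
fn bloom_filter_looks_selective(
    build_key_ndv: u64,
    build_table_rows: u64,
    selectivity_threshold: u64,
) -> bool {
    build_key_ndv > 0 && build_table_rows / build_key_ndv >= selectivity_threshold
}
```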
diff --git a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/builder.rs b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/builder.rs
index 9af1bec4b542b..f35606c6f5853 100644
--- a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/builder.rs
+++ b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/builder.rs
@@ -12,191 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::collections::HashMap;
-use std::time::Instant;
-
-use databend_common_base::runtime::profile::Profile;
-use databend_common_base::runtime::profile::ProfileStatisticsName;
-use databend_common_exception::Result;
-use databend_common_expression::type_check;
-use databend_common_expression::types::DataType;
-use databend_common_expression::Column;
-use databend_common_expression::DataBlock;
-use databend_common_expression::Evaluator;
-use databend_common_expression::Expr;
-use databend_common_expression::FunctionContext;
-use databend_common_expression::RawExpr;
-use databend_common_expression::Scalar;
-use databend_common_functions::BUILTIN_FUNCTIONS;
-
-use super::packet::JoinRuntimeFilterPacket;
-use super::packet::RuntimeFilterPacket;
-use super::packet::SerializableDomain;
 use crate::pipelines::processors::transforms::hash_join::desc::RuntimeFilterDesc;
-use crate::pipelines::processors::transforms::hash_join::util::hash_by_method_for_bloom;
-
-struct JoinRuntimeFilterPacketBuilder<'a> {
-    build_key_column: Column,
-    func_ctx: &'a FunctionContext,
-    inlist_threshold: usize,
-    bloom_threshold: usize,
-    min_max_threshold: usize,
-    selectivity_threshold: u64,
-}
-
-impl<'a> JoinRuntimeFilterPacketBuilder<'a> {
-    fn new(
-        data_blocks: &'a [DataBlock],
-        func_ctx: &'a FunctionContext,
-        build_key: &Expr,
-        inlist_threshold: usize,
-        bloom_threshold: usize,
-        min_max_threshold: usize,
-        selectivity_threshold: u64,
-    ) -> Result<Self> {
-        let build_key_column = Self::eval_build_key_column(data_blocks, func_ctx, build_key)?;
-        Ok(Self {
-            func_ctx,
-            build_key_column,
-            inlist_threshold,
-            bloom_threshold,
-            min_max_threshold,
-            selectivity_threshold,
-        })
-    }
-
-    fn build(&self, desc: &RuntimeFilterDesc) -> Result<RuntimeFilterPacket> {
-        if !should_enable_runtime_filter(
-            desc,
-            self.build_key_column.len(),
-            self.selectivity_threshold,
-        ) {
-            return Ok(RuntimeFilterPacket {
-                id: desc.id,
-                inlist: None,
-                min_max: None,
-                bloom: None,
-            });
-        }
-        let start = Instant::now();
-
-        let min_max_start = Instant::now();
-        let min_max = self
-            .enable_min_max(desc)
-            .then(|| self.build_min_max())
-            .transpose()?;
-        let min_max_time = min_max_start.elapsed();
-
-        let inlist_start = Instant::now();
-        let inlist = self
-            .enable_inlist(desc)
-            .then(|| self.build_inlist())
-            .transpose()?;
-        let inlist_time = inlist_start.elapsed();
-
-        let bloom_start = Instant::now();
-        let bloom = self
-            .enable_bloom(desc)
-            .then(|| self.build_bloom(desc))
-            .transpose()?;
-        let bloom_time = bloom_start.elapsed();
-
-        let total_time = start.elapsed();
-
-        Profile::record_usize_profile(
-            ProfileStatisticsName::RuntimeFilterBuildTime,
-            total_time.as_nanos() as usize,
-        );
-
-        log::info!(
-            "RUNTIME-FILTER: Built filter {} - total: {:?}, min_max: {:?}, inlist: {:?}, bloom: {:?}, rows: {}",
-            desc.id, total_time, min_max_time, inlist_time, bloom_time, self.build_key_column.len()
-        );
-
-        Ok(RuntimeFilterPacket {
-            id: desc.id,
-            min_max,
-            inlist,
-            bloom,
-        })
-    }
-
-    fn enable_min_max(&self, desc: &RuntimeFilterDesc) -> bool {
-        desc.enable_min_max_runtime_filter && self.build_key_column.len() < self.min_max_threshold
-    }
-
-    fn enable_inlist(&self, desc: &RuntimeFilterDesc) -> bool {
-        desc.enable_inlist_runtime_filter && self.build_key_column.len() < self.inlist_threshold
-    }
-
-    fn enable_bloom(&self, desc: &RuntimeFilterDesc) -> bool {
-        if !desc.enable_bloom_runtime_filter {
-            return false;
-        }
-
-        if self.build_key_column.len() >= self.bloom_threshold {
-            return false;
-        }
-
-        true
-    }
-
-    fn build_min_max(&self) -> Result<SerializableDomain> {
-        let domain = self.build_key_column.remove_nullable().domain();
-        let (min, max) = domain.to_minmax();
-        Ok(SerializableDomain { min, max })
-    }
-
-    fn build_inlist(&self) -> Result<Column> {
-        self.dedup_column(&self.build_key_column)
-    }
-
-    fn build_bloom(&self, desc: &RuntimeFilterDesc) -> Result<Vec<u64>> {
-        let data_type = desc.build_key.data_type();
-        let num_rows = self.build_key_column.len();
-        let method = DataBlock::choose_hash_method_with_types(&[data_type.clone()])?;
-        let mut hashes = Vec::with_capacity(num_rows);
-        let key_columns = &[self.build_key_column.clone().into()];
-        hash_by_method_for_bloom(&method, key_columns.into(), num_rows, &mut hashes)?;
-        Ok(hashes)
-    }
-
-    fn eval_build_key_column(
-        data_blocks: &[DataBlock],
-        func_ctx: &FunctionContext,
-        build_key: &Expr,
-    ) -> Result<Column> {
-        let mut columns = Vec::with_capacity(data_blocks.len());
-        for block in data_blocks.iter() {
-            let evaluator = Evaluator::new(block, func_ctx, &BUILTIN_FUNCTIONS);
-            let column = evaluator
-                .run(build_key)?
-                .convert_to_full_column(build_key.data_type(), block.num_rows());
-            columns.push(column);
-        }
-        Column::concat_columns(columns.into_iter())
-    }
-
-    fn dedup_column(&self, column: &Column) -> Result<Column> {
-        let array = RawExpr::Constant {
-            span: None,
-            scalar: Scalar::Array(column.clone()),
-            data_type: Some(DataType::Array(Box::new(column.data_type()))),
-        };
-        let distinct_list = RawExpr::FunctionCall {
-            span: None,
-            name: "array_distinct".to_string(),
-            params: vec![],
-            args: vec![array],
-        };
-
-        let empty_key_block = DataBlock::empty();
-        let evaluator = Evaluator::new(&empty_key_block, self.func_ctx, &BUILTIN_FUNCTIONS);
-        let value = evaluator.run(&type_check::check(&distinct_list, &BUILTIN_FUNCTIONS)?)?;
-        let array = value.into_scalar().unwrap().into_array().unwrap();
-        Ok(array)
-    }
-}
-
 pub(super) fn should_enable_runtime_filter(
     desc: &RuntimeFilterDesc,
     build_num_rows: usize,
@@ -238,48 +54,3 @@ pub(super) fn should_enable_runtime_filter(
         false
     }
 }
-
-pub fn build_runtime_filter_packet(
-    build_chunks: &[DataBlock],
-    build_num_rows: usize,
-    runtime_filter_desc: &[RuntimeFilterDesc],
-    func_ctx: &FunctionContext,
-    inlist_threshold: usize,
-    bloom_threshold: usize,
-    min_max_threshold: usize,
-    selectivity_threshold: u64,
-    is_spill_happened: bool,
-) -> Result<JoinRuntimeFilterPacket> {
-    if is_spill_happened {
-        return Ok(JoinRuntimeFilterPacket::disable_all(
-            runtime_filter_desc,
-            build_num_rows,
-        ));
-    }
-    if build_num_rows == 0 {
-        return Ok(JoinRuntimeFilterPacket {
-            packets: None,
-            build_rows: build_num_rows,
-        });
-    }
-    let mut runtime_filters = HashMap::new();
-    for rf in runtime_filter_desc {
-        runtime_filters.insert(
-            rf.id,
-            JoinRuntimeFilterPacketBuilder::new(
-                build_chunks,
-                func_ctx,
-                &rf.build_key,
-                inlist_threshold,
-                bloom_threshold,
-                min_max_threshold,
-                selectivity_threshold,
-            )?
-            .build(rf)?,
-        );
-    }
-    Ok(JoinRuntimeFilterPacket {
-        packets: Some(runtime_filters),
-        build_rows: build_num_rows,
-    })
-}
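With `build_runtime_filter_packet` removed, per-worker packet construction lives in `RuntimeFilterLocalBuilder` (see `local_builder.rs` further down), and a packet's bloom slot now passes through two phases. A rough sketch of that life cycle; the `Stage`/`globalize` names are illustrative only, the real entry points being `RuntimeFilterLocalBuilder` and `convert_packet_bloom_hashes_to_filter` below:

```rust
use databend_common_catalog::sbbf::Sbbf;

// Phase 1 (per worker): collect raw 64-bit key hashes while consuming the
// build side. Phase 2 (before any cross-node merge): freeze the hashes into
// a fixed-size SBBF so that peers can union bit-compatible filters.
enum Stage {
    Local(Vec<u64>),
    Global(Vec<u8>),
}

fn globalize(stage: Stage) -> Stage {
    match stage {
        Stage::Local(hashes) => {
            let mut filter = Sbbf::new_with_num_of_bytes(64 * 1024 * 1024);
            filter.insert_hash_batch(&hashes);
            Stage::Global(filter.to_bytes())
        }
        frozen => frozen,
    }
}
```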
diff --git a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/convert.rs b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/convert.rs
index 90f2a8a7737fa..ed7500f22b806 100644
--- a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/convert.rs
+++ b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/convert.rs
@@ -36,7 +36,9 @@ use databend_common_expression::Scalar;
 use databend_common_functions::BUILTIN_FUNCTIONS;
 
 use super::builder::should_enable_runtime_filter;
+use super::packet::BloomPayload;
 use super::packet::JoinRuntimeFilterPacket;
+use super::packet::SerializableBloomFilter;
 use super::packet::SerializableDomain;
 use crate::pipelines::processors::transforms::hash_join::desc::RuntimeFilterDesc;
 use crate::pipelines::processors::transforms::hash_join::util::min_max_filter;
@@ -119,6 +121,52 @@ pub async fn build_runtime_filter_infos(
     Ok(filters)
 }
 
+/// Convert bloom payloads that carry hash lists into serialized bloom filters.
+/// This is used in distributed hash shuffle joins before the global merge,
+/// so that only compact bloom filters are transmitted between nodes.
+pub fn convert_packet_bloom_hashes_to_filter(
+    packet: &mut JoinRuntimeFilterPacket,
+    runtime_filter_descs: &HashMap<usize, &RuntimeFilterDesc>,
+    max_threads: usize,
+) -> Result<()> {
+    let Some(ref mut map) = packet.packets else {
+        return Ok(());
+    };
+
+    for (id, rf) in map.iter_mut() {
+        let desc = match runtime_filter_descs.get(id) {
+            Some(d) => *d,
+            None => continue,
+        };
+
+        // If we don't have global build table statistics, we cannot guarantee
+        // a consistent bloom filter size across nodes. Disable bloom for this
+        // runtime filter in distributed mode.
+        if desc.build_table_rows.is_none() {
+            rf.bloom = None;
+            continue;
+        }
+
+        if let Some(bloom) = rf.bloom.take() {
+            rf.bloom = Some(match bloom {
+                BloomPayload::Hashes(hashes) => {
+                    // If there are no hashes, keep bloom disabled.
+                    if hashes.is_empty() {
+                        continue;
+                    }
+
+                    let filter = build_sbbf_from_hashes(hashes, max_threads, rf.id, true)?;
+                    BloomPayload::Filter(SerializableBloomFilter {
+                        data: filter.to_bytes(),
+                    })
+                }
+                BloomPayload::Filter(filter) => BloomPayload::Filter(filter),
+            });
+        }
+    }
+
+    Ok(())
+}
+
 fn build_inlist_filter(inlist: Column, probe_key: &Expr<String>) -> Result<Expr<String>> {
     if inlist.len() == 0 {
         return Ok(Expr::Constant(Constant {
@@ -256,38 +304,36 @@ fn build_min_max_filter(
     Ok(min_max_filter)
 }
 
-async fn build_bloom_filter(
+const FIXED_SBBF_BYTES: usize = 64 * 1024 * 1024;
+
+fn build_sbbf_from_hashes(
     bloom: Vec<u64>,
-    probe_key: &Expr<String>,
     max_threads: usize,
     filter_id: usize,
-) -> Result<RuntimeFilterBloom> {
-    let probe_key = match probe_key {
-        Expr::ColumnRef(col) => col,
-        // Support simple cast that only changes nullability, e.g. CAST(col AS Nullable(T))
-        Expr::Cast(cast) => match cast.expr.as_ref() {
-            Expr::ColumnRef(col) => col,
-            _ => unreachable!(),
-        },
-        _ => unreachable!(),
-    };
-    let column_name = probe_key.id.to_string();
+    fixed_size: bool,
+) -> Result<Sbbf> {
     let total_items = bloom.len();
     if total_items < 3_000_000 {
-        let mut filter = Sbbf::new_with_ndv_fpp(total_items as u64, 0.01)
-            .map_err(|e| ErrorCode::Internal(e.to_string()))?;
+        let mut filter = if fixed_size {
+            Sbbf::new_with_num_of_bytes(FIXED_SBBF_BYTES)
+        } else {
+            Sbbf::new_with_ndv_fpp(total_items as u64, 0.01)
+                .map_err(|e| ErrorCode::Internal(e.to_string()))?
+        };
         filter.insert_hash_batch(&bloom);
-        return Ok(RuntimeFilterBloom {
-            column_name,
-            filter: Arc::new(filter),
-        });
+        return Ok(filter);
     }
 
     let start = std::time::Instant::now();
-    let builder = SbbfAtomic::new_with_ndv_fpp(total_items as u64, 0.01)
-        .map_err(|e| ErrorCode::Internal(e.to_string()))?
-        .insert_hash_batch_parallel(bloom, max_threads);
+    let builder = if fixed_size {
+        SbbfAtomic::new_with_num_of_bytes(FIXED_SBBF_BYTES)
+            .insert_hash_batch_parallel(bloom, max_threads)
+    } else {
+        SbbfAtomic::new_with_ndv_fpp(total_items as u64, 0.01)
+            .map_err(|e| ErrorCode::Internal(e.to_string()))?
+            .insert_hash_batch_parallel(bloom, max_threads)
+    };
     let filter = builder.finish();
     log::info!(
         "filter_id: {}, build_time: {:?}",
@@ -295,10 +341,42 @@
         start.elapsed()
     );
 
-    Ok(RuntimeFilterBloom {
-        column_name,
-        filter: Arc::new(filter),
-    })
+    Ok(filter)
+}
+
+async fn build_bloom_filter(
+    bloom: BloomPayload,
+    probe_key: &Expr<String>,
+    max_threads: usize,
+    filter_id: usize,
+) -> Result<RuntimeFilterBloom> {
+    let probe_key = match probe_key {
+        Expr::ColumnRef(col) => col,
+        // Support a simple cast that only changes nullability, e.g.
+        // CAST(col AS Nullable(T)).
+        Expr::Cast(cast) => match cast.expr.as_ref() {
+            Expr::ColumnRef(col) => col,
+            _ => unreachable!(),
+        },
+        _ => unreachable!(),
+    };
+    let column_name = probe_key.id.to_string();
+
+    match bloom {
+        BloomPayload::Hashes(hashes) => {
+            let filter = build_sbbf_from_hashes(hashes, max_threads, filter_id, false)?;
+            Ok(RuntimeFilterBloom {
+                column_name,
+                filter: Arc::new(filter),
+            })
+        }
+        BloomPayload::Filter(serialized) => {
+            let filter = Sbbf::from_bytes(&serialized.data)
+                .map_err(|e| ErrorCode::Internal(e.to_string()))?;
+            Ok(RuntimeFilterBloom {
+                column_name,
+                filter: Arc::new(filter),
+            })
+        }
+    }
+}
 
 #[cfg(test)]
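`FIXED_SBBF_BYTES` exists because `Sbbf::union` (used during the global merge below) is only sound when both filters have identical bitset lengths: the hash-to-block mapping depends on the number of blocks. A guard sketch built only from this patch's `to_bytes`/`from_bytes` API, showing why equal sizes reduce union to a plain bitwise OR; `try_union` is illustrative, not the patch's implementation:

```rust
fn try_union(a: &Sbbf, b: &Sbbf) -> Result<Sbbf, String> {
    let (mut left, right) = (a.to_bytes(), b.to_bytes());
    // Differently sized filters map the same hash to different block/bit
    // positions, so a byte-wise OR would be meaningless.
    if left.len() != right.len() {
        return Err("cannot union bloom filters of different sizes".to_string());
    }
    for (l, r) in left.iter_mut().zip(right.iter()) {
        *l |= r; // byte-wise OR equals block-wise OR on the SBBF bitset
    }
    Sbbf::from_bytes(&left)
}
```

Pinning every node to the same 64 MiB size trades some network volume for the guarantee that serialized filters from different workers line up byte for byte.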
diff --git a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/interface.rs b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/interface.rs
index d0f1fb003fb1e..1a5c344e7a23d 100644
--- a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/interface.rs
+++ b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/interface.rs
@@ -18,6 +18,7 @@ use databend_common_exception::Result;
 use databend_common_storages_fuse::TableContext;
 
 use super::convert::build_runtime_filter_infos;
+use super::convert::convert_packet_bloom_hashes_to_filter;
 use super::global::get_global_runtime_filter_packet;
 use crate::pipelines::processors::transforms::JoinRuntimeFilterPacket;
 use crate::pipelines::processors::HashJoinBuildState;
@@ -28,7 +29,22 @@ pub async fn build_and_push_down_runtime_filter(
 ) -> Result<()> {
     let overall_start = Instant::now();
 
+    let max_threads = join.ctx.get_settings().get_max_threads()? as usize;
+    let runtime_filter_descs: std::collections::HashMap<
+        usize,
+        &crate::pipelines::processors::transforms::hash_join::desc::RuntimeFilterDesc,
+    > = join
+        .runtime_filter_desc()
+        .iter()
+        .map(|r| (r.id, r))
+        .collect();
+
     if let Some(broadcast_id) = join.broadcast_id {
+        // For distributed hash shuffle joins, convert bloom hashes into compact
+        // bloom filters before the global merge so that only the filters are
+        // transmitted between nodes.
+        convert_packet_bloom_hashes_to_filter(&mut packet, &runtime_filter_descs, max_threads)?;
+
         let merge_start = Instant::now();
         packet = get_global_runtime_filter_packet(broadcast_id, packet, &join.ctx).await?;
         let merge_time = merge_start.elapsed();
@@ -38,16 +54,10 @@ pub async fn build_and_push_down_runtime_filter(
         );
     }
 
-    let runtime_filter_descs = join
-        .runtime_filter_desc()
-        .iter()
-        .map(|r| (r.id, r))
-        .collect();
     let selectivity_threshold = join
         .ctx
         .get_settings()
         .get_join_runtime_filter_selectivity_threshold()?;
-    let max_threads = join.ctx.get_settings().get_max_threads()? as usize;
 
     let build_rows = packet.build_rows;
     let runtime_filter_infos = build_runtime_filter_infos(
         packet,
diff --git a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/local_builder.rs b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/local_builder.rs
index e91b51ce1ffcf..5134157982f16 100644
--- a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/local_builder.rs
+++ b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/local_builder.rs
@@ -27,6 +27,7 @@ use databend_common_expression::Scalar;
 use databend_common_functions::BUILTIN_FUNCTIONS;
 
 use crate::pipelines::processors::transforms::hash_join::desc::RuntimeFilterDesc;
+use crate::pipelines::processors::transforms::hash_join::runtime_filter::packet::BloomPayload;
 use crate::pipelines::processors::transforms::hash_join::runtime_filter::packet::JoinRuntimeFilterPacket;
 use crate::pipelines::processors::transforms::hash_join::runtime_filter::packet::RuntimeFilterPacket;
 use crate::pipelines::processors::transforms::hash_join::runtime_filter::packet::SerializableDomain;
@@ -153,7 +154,7 @@ impl SingleFilterBuilder {
             None
         };
 
-        let bloom = self.bloom_hashes.take();
+        let bloom = self.bloom_hashes.take().map(BloomPayload::Hashes);
 
         Ok(RuntimeFilterPacket {
             id: self.id,
diff --git a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/merge.rs b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/merge.rs
index 0f6d40a04da1a..959d7c19a40f1 100644
--- a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/merge.rs
+++ b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/merge.rs
@@ -14,11 +14,14 @@
 
 use std::collections::HashMap;
 
+use databend_common_catalog::sbbf::Sbbf;
 use databend_common_exception::Result;
 use databend_common_expression::Column;
 
+use super::packet::BloomPayload;
 use super::packet::JoinRuntimeFilterPacket;
 use super::packet::RuntimeFilterPacket;
+use super::packet::SerializableBloomFilter;
 use super::packet::SerializableDomain;
 
 pub fn merge_join_runtime_filter_packets(
@@ -121,23 +124,93 @@ fn merge_min_max(
     Some(SerializableDomain { min, max })
 }
 
-fn merge_bloom(packets: &[HashMap<usize, RuntimeFilterPacket>], rf_id: usize) -> Option<Vec<u64>> {
+fn merge_bloom(
+    packets: &[HashMap<usize, RuntimeFilterPacket>],
+    rf_id: usize,
+) -> Option<BloomPayload> {
     if packets
         .iter()
         .any(|packet| packet.get(&rf_id).unwrap().bloom.is_none())
     {
         return None;
     }
-    let mut bloom = packets[0]
-        .get(&rf_id)
-        .unwrap()
-        .bloom
-        .as_ref()
-        .unwrap()
-        .clone();
-    for packet in packets.iter().skip(1) {
-        let other = packet.get(&rf_id).unwrap().bloom.as_ref().unwrap();
-        bloom.extend_from_slice(other);
+
+    let first = packets[0].get(&rf_id).unwrap().bloom.as_ref().unwrap();
+    match first {
+        BloomPayload::Hashes(first_hashes) => {
+            // Local merge path: concatenate the hash lists.
+            let mut merged = first_hashes.clone();
+            for packet in packets.iter().skip(1) {
+                let other = packet.get(&rf_id).unwrap().bloom.as_ref().unwrap();
+                match other {
+                    BloomPayload::Hashes(hashes) => merged.extend_from_slice(hashes),
+                    BloomPayload::Filter(_) => {
+                        // Mixed variants are not expected today; fall back to
+                        // disabling bloom for this filter.
+                        log::warn!(
+                            "RUNTIME-FILTER: mixed bloom payload variants detected for id {}, disabling bloom merge",
+                            rf_id
+                        );
+                        return None;
+                    }
+                }
+            }
+            Some(BloomPayload::Hashes(merged))
+        }
+        BloomPayload::Filter(first_filter) => {
+            // Global merge path: union the serialized bloom filters.
+            let mut base = match Sbbf::from_bytes(&first_filter.data) {
+                Ok(bf) => bf,
+                Err(e) => {
+                    log::warn!(
+                        "RUNTIME-FILTER: failed to deserialize bloom filter for id {}: {}",
+                        rf_id,
+                        e
+                    );
+                    return None;
+                }
+            };
+
+            for packet in packets.iter().skip(1) {
+                let other = packet.get(&rf_id).unwrap().bloom.as_ref().unwrap();
+                let bytes = match other {
+                    BloomPayload::Filter(f) => &f.data,
+                    BloomPayload::Hashes(_) => {
+                        log::warn!(
+                            "RUNTIME-FILTER: mixed bloom payload variants detected for id {}, disabling bloom merge",
+                            rf_id
+                        );
+                        return None;
+                    }
+                };
+
+                let other_bf = match Sbbf::from_bytes(bytes) {
+                    Ok(bf) => bf,
+                    Err(e) => {
+                        log::warn!(
+                            "RUNTIME-FILTER: failed to deserialize bloom filter for id {}: {}",
+                            rf_id,
+                            e
+                        );
+                        return None;
+                    }
+                };
+
+                base.union(&other_bf);
+            }
+
+            Some(BloomPayload::Filter(SerializableBloomFilter {
+                data: base.to_bytes(),
+            }))
+        }
     }
-    Some(bloom)
 }
diff --git a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/mod.rs b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/mod.rs
index 618512a3f5f79..2689c169a8960 100644
--- a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/mod.rs
+++ b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/mod.rs
@@ -20,8 +20,8 @@ mod local_builder;
 mod merge;
 mod packet;
 
-pub use builder::build_runtime_filter_packet;
 pub use convert::build_runtime_filter_infos;
+pub use convert::convert_packet_bloom_hashes_to_filter;
 pub use global::get_global_runtime_filter_packet;
 pub use interface::build_and_push_down_runtime_filter;
 pub use local_builder::RuntimeFilterLocalBuilder;
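The correctness requirement for the global merge path is that union never introduces false negatives: any hash present in either input must still be reported by the merged filter. A property-test sketch, not part of the patch, assuming the `insert_hash_batch`, `union`, and a `check_hash` lookup method that the surrounding code relies on:

```rust
#[test]
fn union_has_no_false_negatives() {
    // Both filters must be created with the same size (see FIXED_SBBF_BYTES).
    let mut a = Sbbf::new_with_num_of_bytes(1024);
    let mut b = Sbbf::new_with_num_of_bytes(1024);
    a.insert_hash_batch(&[1, 2, 3]);
    b.insert_hash_batch(&[100, 200]);

    a.union(&b);
    for hash in [1u64, 2, 3, 100, 200] {
        assert!(a.check_hash(hash));
    }
}
```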
diff --git a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/packet.rs b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/packet.rs
index 9e30ac18b5bb1..538457f996758 100644
--- a/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/packet.rs
+++ b/src/query/service/src/pipelines/processors/transforms/hash_join/runtime_filter/packet.rs
@@ -23,6 +23,18 @@ use databend_common_expression::Scalar;
 
 use crate::pipelines::processors::transforms::RuntimeFilterDesc;
 
+/// Bloom filter payload used in a runtime filter packet.
+#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq)]
+pub enum BloomPayload {
+    Hashes(Vec<u64>),
+    Filter(SerializableBloomFilter),
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, PartialEq)]
+pub struct SerializableBloomFilter {
+    pub data: Vec<u8>,
+}
+
 /// Represents a runtime filter that can be transmitted and merged.
 ///
 /// # Fields
@@ -30,13 +42,13 @@ use crate::pipelines::processors::transforms::RuntimeFilterDesc;
 /// * `id` - Unique identifier for each runtime filter, corresponds one-to-one with `(build key, probe key)` pair
 /// * `inlist` - Deduplicated list of build key column
 /// * `min_max` - The min and max values of the build column
-/// * `bloom` - The deduplicated hashes of the build column
+/// * `bloom` - Bloom filter payload for the build column
 #[derive(serde::Serialize, serde::Deserialize, Clone, Default, PartialEq)]
 pub struct RuntimeFilterPacket {
     pub id: usize,
     pub inlist: Option<Column>,
     pub min_max: Option<SerializableDomain>,
-    pub bloom: Option<Vec<u64>>,
+    pub bloom: Option<BloomPayload>,
 }
 
 impl Debug for RuntimeFilterPacket {
diff --git a/src/query/service/src/pipelines/processors/transforms/new_hash_join/runtime_filter.rs b/src/query/service/src/pipelines/processors/transforms/new_hash_join/runtime_filter.rs
index 4757d0510a7d9..3221de91a18b6 100644
--- a/src/query/service/src/pipelines/processors/transforms/new_hash_join/runtime_filter.rs
+++ b/src/query/service/src/pipelines/processors/transforms/new_hash_join/runtime_filter.rs
@@ -22,6 +22,7 @@ use databend_common_expression::FunctionContext;
 
 use crate::physical_plans::HashJoin;
 use crate::pipelines::processors::transforms::build_runtime_filter_infos;
+use crate::pipelines::processors::transforms::convert_packet_bloom_hashes_to_filter;
 use crate::pipelines::processors::transforms::get_global_runtime_filter_packet;
 use crate::pipelines::processors::transforms::JoinRuntimeFilterPacket;
 use crate::pipelines::processors::transforms::RuntimeFilterDesc;
@@ -81,11 +82,19 @@ impl RuntimeFiltersDesc {
     }
 
     pub async fn globalization(&self, mut packet: JoinRuntimeFilterPacket) -> Result<()> {
+        let runtime_filter_descs: std::collections::HashMap<usize, &RuntimeFilterDesc> =
+            self.filters_desc.iter().map(|r| (r.id, r)).collect();
+
         if let Some(broadcast_id) = self.broadcast_id {
+            let max_threads = self.ctx.get_settings().get_max_threads()? as usize;
+            // For distributed hash shuffle joins, convert bloom hashes into compact
+            // bloom filters before the global merge so that only the filters are
+            // transmitted between nodes.
+            convert_packet_bloom_hashes_to_filter(&mut packet, &runtime_filter_descs, max_threads)?;
+
             packet = get_global_runtime_filter_packet(broadcast_id, packet, &self.ctx).await?;
         }
 
-        let runtime_filter_descs = self.filters_desc.iter().map(|r| (r.id, r)).collect();
         let runtime_filter_infos = build_runtime_filter_infos(
             packet,
             runtime_filter_descs,
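Packets cross node boundaries through the broadcast channel, which is why the new payload types derive the serde traits. A round-trip sketch; `bincode` is an arbitrary stand-in here, since the patch does not show which codec the channel actually uses:

```rust
fn payload_round_trip() {
    // A Filter payload wraps the raw SBBF bytes produced by Sbbf::to_bytes.
    let payload = BloomPayload::Filter(SerializableBloomFilter {
        data: vec![0u8; 64],
    });

    // Any serde-compatible codec preserves the enum variant and its bytes.
    let bytes = bincode::serialize(&payload).unwrap();
    let back: BloomPayload = bincode::deserialize(&bytes).unwrap();
    assert_eq!(back, payload);
}
```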