Skip to content

Commit 67f109a

Browse files
committed
fix: scan projection and joins
1 parent 0c25493 commit 67f109a

File tree

3 files changed

+195
-93
lines changed

3 files changed

+195
-93
lines changed

executor/src/pgscan.rs

Lines changed: 119 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::any::Any;
22
use std::collections::HashMap;
33
use std::pin::Pin;
44
use std::sync::{Arc, Mutex, MutexGuard};
5+
use std::sync::atomic::{AtomicU64, Ordering, AtomicU16};
56
use std::task::{Context, Poll};
67

78
use crate::shm;
@@ -13,7 +14,7 @@ use datafusion::arrow::array::{
1314
};
1415
use datafusion::arrow::datatypes::DataType as ArrowDataType;
1516
use datafusion::arrow::datatypes::SchemaRef;
16-
use datafusion::arrow::record_batch::RecordBatch;
17+
use datafusion::arrow::record_batch::{RecordBatch, RecordBatchOptions};
1718
use datafusion::common::Result as DFResult;
1819
use datafusion::error::DataFusionError;
1920
use datafusion::logical_expr::TableProviderFilterPushDown;
@@ -52,18 +53,25 @@ pub struct HeapBlock {
5253
pub struct ScanRegistry {
5354
inner: Mutex<HashMap<ScanId, Entry>>,
5455
conn_id: usize,
56+
next_id: AtomicU64,
57+
next_slot: AtomicU16,
5558
}
5659

5760
#[derive(Debug)]
5861
struct Entry {
5962
sender: mpsc::Sender<HeapBlock>,
6063
receiver: Option<mpsc::Receiver<HeapBlock>>,
64+
slot_id: u16,
6165
}
6266

6367
impl ScanRegistry {
64-
pub fn new() -> Self { Self { inner: Mutex::new(HashMap::new()), conn_id: 0 } }
68+
pub fn new() -> Self {
69+
Self { inner: Mutex::new(HashMap::new()), conn_id: 0, next_id: AtomicU64::new(1), next_slot: AtomicU16::new(0) }
70+
}
6571

66-
pub fn with_conn(conn_id: usize) -> Self { Self { inner: Mutex::new(HashMap::new()), conn_id } }
72+
pub fn with_conn(conn_id: usize) -> Self {
73+
Self { inner: Mutex::new(HashMap::new()), conn_id, next_id: AtomicU64::new(1), next_slot: AtomicU16::new(0) }
74+
}
6775

6876
#[inline]
6977
fn lock(&self) -> MutexGuard<HashMap<ScanId, Entry>> {
@@ -77,14 +85,22 @@ impl ScanRegistry {
7785
map.reserve(additional);
7886
}
7987

88+
/// Allocate a unique scan identifier for this connection.
89+
pub fn allocate_id(&self) -> ScanId {
90+
self.next_id.fetch_add(1, Ordering::Relaxed)
91+
}
92+
8093
pub fn register(&self, scan_id: ScanId, capacity: usize) -> mpsc::Sender<HeapBlock> {
94+
// Allocate alternating slot ids 0,1 for concurrent scans
95+
let slot = self.next_slot.fetch_add(1, Ordering::Relaxed) % 2;
8196
let (tx, rx) = mpsc::channel(capacity);
8297
let mut map = self.lock();
8398
map.insert(
8499
scan_id,
85100
Entry {
86101
sender: tx.clone(),
87102
receiver: Some(rx),
103+
slot_id: slot,
88104
},
89105
);
90106
tx
@@ -100,6 +116,12 @@ impl ScanRegistry {
100116
map.get_mut(&scan_id).and_then(|e| e.receiver.take())
101117
}
102118

119+
/// Get the assigned slot id for this scan
120+
pub fn slot_for(&self, scan_id: ScanId) -> Option<u16> {
121+
let map = self.lock();
122+
map.get(&scan_id).map(|e| e.slot_id)
123+
}
124+
103125
/// Close all channels and clear the registry.
104126
/// Dropping the senders will cause receivers to observe stream termination.
105127
pub fn close_and_clear(&self) {
@@ -122,15 +144,15 @@ impl ScanRegistry {
122144
#[derive(Debug)]
123145
pub struct PgTableProvider {
124146
schema: SchemaRef,
125-
scan_id: ScanId,
147+
table_oid: u32,
126148
registry: Arc<ScanRegistry>,
127149
}
128150

129151
impl PgTableProvider {
130-
pub fn new(scan_id: ScanId, schema: SchemaRef, registry: Arc<ScanRegistry>) -> Self {
152+
pub fn new(table_oid: u32, schema: SchemaRef, registry: Arc<ScanRegistry>) -> Self {
131153
Self {
132154
schema,
133-
scan_id,
155+
table_oid,
134156
registry,
135157
}
136158
}
@@ -164,9 +186,26 @@ impl TableProvider for PgTableProvider {
164186
_filters: &[Expr],
165187
_limit: Option<usize>,
166188
) -> DFResult<Arc<dyn ExecutionPlan>> {
167-
let exec = PgScanExec::new(
168-
self.scan_id,
169-
Arc::clone(&self.schema),
189+
// Respect requested projection; DataFusion expects the physical plan schema
190+
// to match the projected logical input schema.
191+
let proj_indices: Vec<usize> = match _projection {
192+
Some(ix) => ix.clone(),
193+
None => (0..self.schema.fields().len()).collect(),
194+
};
195+
let proj_fields: Vec<datafusion::arrow::datatypes::Field> = proj_indices
196+
.iter()
197+
.map(|&i| self.schema.field(i).clone())
198+
.collect();
199+
let proj_schema = Arc::new(datafusion::arrow::datatypes::Schema::new(proj_fields));
200+
201+
// Allocate a fresh, unique scan id for every scan instance (handles self-joins)
202+
let scan_id = self.registry.allocate_id();
203+
let exec = PgScanExec::with_projection(
204+
scan_id,
205+
self.table_oid,
206+
Arc::clone(&self.schema), // full schema for decoding
207+
proj_schema, // physical output schema
208+
proj_indices,
170209
Arc::clone(&self.registry),
171210
);
172211
Ok(Arc::new(exec))
@@ -175,31 +214,48 @@ impl TableProvider for PgTableProvider {
175214

176215
#[derive(Debug)]
177216
pub struct PgScanExec {
178-
schema: SchemaRef,
217+
// Full table schema (all columns) used to compute attribute metadata for decoding
218+
full_schema: SchemaRef,
219+
// Projected schema exposed by this plan node
220+
proj_schema: SchemaRef,
221+
// Postgres table OID for this scan
222+
table_oid: u32,
179223
scan_id: ScanId,
180224
registry: Arc<ScanRegistry>,
181225
props: PlanProperties,
182-
attrs: Arc<Vec<PgAttrMeta>>,
226+
// Attribute metadata for the full schema (decoder needs all columns to walk offsets)
227+
attrs_full: Arc<Vec<PgAttrMeta>>,
228+
// Indices of columns to project, in terms of full schema
229+
proj_indices: Arc<Vec<usize>>,
183230
}
184231

185232
impl PgScanExec {
186-
pub fn new(scan_id: ScanId, schema: SchemaRef, registry: Arc<ScanRegistry>) -> Self {
187-
let eq = EquivalenceProperties::new(schema.clone());
233+
pub fn with_projection(
234+
scan_id: ScanId,
235+
table_oid: u32,
236+
full_schema: SchemaRef,
237+
proj_schema: SchemaRef,
238+
proj_indices: Vec<usize>,
239+
registry: Arc<ScanRegistry>,
240+
) -> Self {
241+
// Execution plan schema must match the projected schema
242+
let eq = EquivalenceProperties::new(proj_schema.clone());
188243
let props = PlanProperties::new(
189244
eq,
190245
Partitioning::UnknownPartitioning(1),
191246
EmissionType::Incremental,
192-
Boundedness::Unbounded {
193-
requires_infinite_memory: false,
194-
},
247+
Boundedness::Bounded,
195248
);
196-
let attrs = Arc::new(attrs_from_schema(&schema));
249+
let attrs_full = Arc::new(attrs_from_schema(&full_schema));
197250
Self {
198-
schema,
251+
full_schema,
252+
proj_schema,
253+
table_oid,
199254
scan_id,
200255
registry,
201256
props,
202-
attrs,
257+
attrs_full,
258+
proj_indices: Arc::new(proj_indices),
203259
}
204260
}
205261
}
@@ -210,8 +266,8 @@ impl DisplayAs for PgScanExec {
210266
DisplayFormatType::Default => write!(f, "PgScanExec: scan_id={}", self.scan_id),
211267
DisplayFormatType::Verbose => write!(
212268
f,
213-
"PgScanExec: scan_id={}, schema={:?}",
214-
self.scan_id, self.schema
269+
"PgScanExec: scan_id={}, table_oid={}, proj_schema={:?}",
270+
self.scan_id, self.table_oid, self.proj_schema
215271
),
216272
}
217273
}
@@ -255,30 +311,41 @@ impl ExecutionPlan for PgScanExec {
255311
let rx = self.registry.take_receiver(self.scan_id);
256312
// Use connection id stored in registry to address per-connection slot buffers
257313
let conn_id = self.registry.conn_id();
258-
let stream = PgScanStream::new(Arc::clone(&self.schema), Arc::clone(&self.attrs), rx, conn_id);
314+
let stream = PgScanStream::new(
315+
Arc::clone(&self.proj_schema),
316+
Arc::clone(&self.attrs_full),
317+
Arc::clone(&self.proj_indices),
318+
rx,
319+
conn_id,
320+
);
259321
Ok(Box::pin(stream))
260322
}
261323

262324
fn statistics(&self) -> DFResult<Statistics> {
263-
Ok(Statistics::new_unknown(&self.schema))
325+
Ok(Statistics::new_unknown(&self.proj_schema))
264326
}
265327
}
266328

267329
pub struct PgScanStream {
268-
schema: SchemaRef,
269-
attrs: Arc<Vec<PgAttrMeta>>,
330+
// Schema of output batches
331+
proj_schema: SchemaRef,
332+
// Full attribute metadata for decoding
333+
attrs_full: Arc<Vec<PgAttrMeta>>,
334+
// Projection indices into full schema
335+
proj_indices: Arc<Vec<usize>>,
270336
rx: Option<mpsc::Receiver<HeapBlock>>,
271337
conn_id: usize,
272338
}
273339

274340
impl PgScanStream {
275341
pub fn new(
276-
schema: SchemaRef,
277-
attrs: Arc<Vec<PgAttrMeta>>,
342+
proj_schema: SchemaRef,
343+
attrs_full: Arc<Vec<PgAttrMeta>>,
344+
proj_indices: Arc<Vec<usize>>,
278345
rx: Option<mpsc::Receiver<HeapBlock>>,
279346
conn_id: usize,
280347
) -> Self {
281-
Self { schema, attrs, rx, conn_id }
348+
Self { proj_schema, attrs_full, proj_indices, rx, conn_id }
282349
}
283350
}
284351

@@ -301,19 +368,19 @@ impl Stream for PgScanStream {
301368
};
302369

303370
// Prepare decoding metadata: full-table attrs (precomputed once) and the projected column count
304-
let total_cols = this.schema.fields().len();
371+
let total_cols = this.proj_schema.fields().len();
305372

306373
// Create a HeapPage view and iterate tuples
307374
let hp = unsafe { HeapPage::from_slice(page) };
308375
let Ok(hp) = hp else {
309376
// On error decoding page, return empty batch for resilience
310-
let batch = RecordBatch::new_empty(Arc::clone(&this.schema));
377+
let batch = RecordBatch::new_empty(Arc::clone(&this.proj_schema));
311378
return Poll::Ready(Some(Ok(batch)));
312379
};
313380

314381
// Prepare per-column builders
315382
let col_count = total_cols;
316-
let mut builders = make_builders(&this.schema, block.num_offsets as usize)
383+
let mut builders = make_builders(&this.proj_schema, block.num_offsets as usize)
317384
.map_err(|e| datafusion::error::DataFusionError::Execution(format!("{e}")))?;
318385
// Use tuples_by_offset to iterate LP_NORMAL tuples in page order
319386
let mut pairs: Vec<(u16, u16)> = Vec::new();
@@ -335,12 +402,19 @@ impl Stream for PgScanStream {
335402
as *const pg_sys::PageHeaderData;
336403
let mut decoded_rows = 0usize;
337404
while let Some(tup) = it.next() {
338-
// Decode projected columns for tuple using iterator over all columns
339-
let iter =
340-
unsafe { decode_tuple_project(page_hdr, tup, &this.attrs, 0..total_cols) };
405+
// Decode projected columns for tuple using iterator over requested projection
406+
let iter = unsafe {
407+
decode_tuple_project(
408+
page_hdr,
409+
tup,
410+
&this.attrs_full,
411+
this.proj_indices.iter().copied(),
412+
)
413+
};
341414
let Ok(mut iter) = iter else {
342415
continue;
343416
};
417+
// Iterate over projected columns in order
344418
for col_idx in 0..total_cols {
345419
match iter.next() {
346420
Some(Ok(v)) => append_scalar(&mut builders[col_idx], v),
@@ -356,8 +430,15 @@ impl Stream for PgScanStream {
356430
for b in builders.into_iter() {
357431
arrs.push(finish_builder(b));
358432
}
359-
let rb = RecordBatch::try_new(Arc::clone(&this.schema), arrs)
360-
.map_err(|e| datafusion::error::DataFusionError::Execution(format!("{e}")))?;
433+
let rb = if this.proj_schema.fields().is_empty() {
434+
// Special case: empty projection — use row_count to communicate the number of rows
435+
let opts = RecordBatchOptions::new().with_row_count(Some(decoded_rows));
436+
RecordBatch::try_new_with_options(Arc::clone(&this.proj_schema), vec![], &opts)
437+
.map_err(|e| datafusion::error::DataFusionError::Execution(format!("{e}")))?
438+
} else {
439+
RecordBatch::try_new(Arc::clone(&this.proj_schema), arrs)
440+
.map_err(|e| datafusion::error::DataFusionError::Execution(format!("{e}")))?
441+
};
361442
tracing::trace!(
362443
target = "executor::server",
363444
rows = decoded_rows,
@@ -561,7 +642,7 @@ fn finish_builder(b: ColBuilder) -> ArrayRef {
561642

562643
impl RecordBatchStream for PgScanStream {
563644
fn schema(&self) -> SchemaRef {
564-
Arc::clone(&self.schema)
645+
Arc::clone(&self.proj_schema)
565646
}
566647
}
567648

@@ -589,7 +670,7 @@ where
589670
{
590671
if let Some(p) = node.as_any().downcast_ref::<PgScanExec>() {
591672
let id = p.scan_id;
592-
let table_oid = id as u32; // current convention
673+
let table_oid = p.table_oid;
593674
f(id, table_oid)?;
594675
}
595676
for child in node.children() {

0 commit comments

Comments
 (0)