Commit 9329deb

perf(trie): batch storage proof jobs at worker level
Implement worker-level batching for storage proofs to reduce redundant trie traversals when multiple proof requests arrive for the same account. When storage proof requests queue up faster than workers can process them, jobs for the same account are now merged into a single proof computation. This reduces trie I/O and computation overhead significantly.
1 parent ba58b84 commit 9329deb
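To make the batching idea concrete, here is a minimal, self-contained sketch of the merge step under simplified assumptions: `ProofJob`, `drain_batch`, and the plain `VecDeque`/`HashMap`/`HashSet` containers are illustrative stand-ins for the diff's `StorageProofInput`, the worker's crossbeam channel, and `B256Map`, not the crate's actual API.

// Illustrative sketch only: merge queued storage-proof jobs that target the same
// account so each account's trie is traversed once per batch.
use std::collections::{hash_map::Entry, HashMap, HashSet, VecDeque};

/// Simplified stand-in for a storage proof request (hypothetical type).
#[derive(Debug)]
struct ProofJob {
    hashed_address: [u8; 32],
    target_slots: HashSet<[u8; 32]>,
}

/// Drain up to `limit` queued jobs, merging same-account jobs into one entry.
fn drain_batch(queue: &mut VecDeque<ProofJob>, limit: usize) -> HashMap<[u8; 32], ProofJob> {
    let mut batches: HashMap<[u8; 32], ProofJob> = HashMap::new();
    let mut taken = 0;
    while taken < limit {
        let Some(job) = queue.pop_front() else { break };
        taken += 1;
        match batches.entry(job.hashed_address) {
            // Same account already queued: union the target slots instead of
            // scheduling a second trie traversal.
            Entry::Occupied(mut existing) => {
                existing.get_mut().target_slots.extend(job.target_slots);
            }
            Entry::Vacant(slot) => {
                slot.insert(job);
            }
        }
    }
    batches
}

fn main() {
    let account = [0xaa; 32];
    let mut queue = VecDeque::from([
        ProofJob { hashed_address: account, target_slots: HashSet::from([[1u8; 32]]) },
        ProofJob { hashed_address: account, target_slots: HashSet::from([[2u8; 32]]) },
        ProofJob { hashed_address: [0xbb; 32], target_slots: HashSet::from([[3u8; 32]]) },
    ]);
    let batches = drain_batch(&mut queue, 32);
    // Three requests collapse into two proof computations; the two jobs for
    // `account` are merged and cover both slots.
    assert_eq!(batches.len(), 2);
    assert_eq!(batches[&account].target_slots.len(), 2);
}

In the diff itself, the drain happens on the worker's crossbeam channel via try_recv(), the batch is capped at STORAGE_PROOF_BATCH_LIMIT (32 jobs), prefix sets are merged alongside target slots, and blinded-node jobs pulled out during the drain are deferred rather than dropped.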

File tree

1 file changed: +249 −54 lines

crates/trie/parallel/src/proof_task.rs

Lines changed: 249 additions & 54 deletions
@@ -41,6 +41,7 @@ use alloy_primitives::{
 use alloy_rlp::{BufMut, Encodable};
 use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
 use dashmap::DashMap;
+use metrics::Histogram;
 use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
 use reth_provider::{DatabaseProviderROFactory, ProviderError, ProviderResult};
 use reth_storage_errors::db::DatabaseError;
@@ -79,6 +80,93 @@ use crate::proof_task_metrics::{
 type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
 type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;
 
+/// Maximum number of storage proof jobs to batch together per account.
+const STORAGE_PROOF_BATCH_LIMIT: usize = 32;
+
+/// Holds batched storage proof jobs for the same account.
+///
+/// When multiple storage proof requests arrive for the same account, they can be merged
+/// into a single proof computation with combined prefix sets and target slots.
+#[derive(Debug)]
+struct BatchedStorageProof {
+    /// The merged prefix set from all batched jobs.
+    prefix_set: PrefixSetMut,
+    /// The merged target slots from all batched jobs.
+    target_slots: B256Set,
+    /// Whether any job requested branch node masks.
+    with_branch_node_masks: bool,
+    /// The `multi_added_removed_keys` from the first job (they should all share the same `Arc`).
+    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
+    /// All senders that need to receive the result.
+    senders: Vec<ProofResultContext>,
+}
+
+impl BatchedStorageProof {
+    /// Creates a new batch from the first storage proof input.
+    fn new(input: StorageProofInput, sender: ProofResultContext) -> Self {
+        // Convert frozen PrefixSet to mutable PrefixSetMut by collecting its keys.
+        let prefix_set = PrefixSetMut::from(input.prefix_set.iter().copied());
+        Self {
+            prefix_set,
+            target_slots: input.target_slots,
+            with_branch_node_masks: input.with_branch_node_masks,
+            multi_added_removed_keys: input.multi_added_removed_keys,
+            senders: vec![sender],
+        }
+    }
+
+    /// Merges another storage proof job into this batch.
+    fn merge(&mut self, input: StorageProofInput, sender: ProofResultContext) {
+        self.prefix_set.extend_keys(input.prefix_set.iter().copied());
+        self.target_slots.extend(input.target_slots);
+        self.with_branch_node_masks |= input.with_branch_node_masks;
+        self.senders.push(sender);
+    }
+
+    /// Converts this batch into a single `StorageProofInput` for computation.
+    fn into_input(self, hashed_address: B256) -> (StorageProofInput, Vec<ProofResultContext>) {
+        let input = StorageProofInput {
+            hashed_address,
+            prefix_set: self.prefix_set.freeze(),
+            target_slots: self.target_slots,
+            with_branch_node_masks: self.with_branch_node_masks,
+            multi_added_removed_keys: self.multi_added_removed_keys,
+        };
+        (input, self.senders)
+    }
+}
+
+/// Metrics for storage worker batching.
+#[derive(Clone, Default)]
+struct StorageWorkerBatchMetrics {
+    /// Histogram of batch sizes (number of jobs merged per computation).
+    #[cfg(feature = "metrics")]
+    batch_size_histogram: Option<Histogram>,
+}
+
+impl StorageWorkerBatchMetrics {
+    #[cfg(feature = "metrics")]
+    fn new() -> Self {
+        Self {
+            batch_size_histogram: Some(metrics::histogram!(
+                "trie.proof_task.storage_worker_batch_size"
+            )),
+        }
+    }
+
+    #[cfg(not(feature = "metrics"))]
+    fn new() -> Self {
+        Self {}
+    }
+
+    fn record_batch_size(&self, _size: usize) {
+        #[cfg(feature = "metrics")]
+        if let Some(h) = &self.batch_size_histogram {
+            h.record(_size as f64);
+        }
+    }
+}
+
 /// A handle that provides type-safe access to proof worker pools.
 ///
 /// The handle stores direct senders to both storage and account worker pools,
@@ -552,7 +640,7 @@ impl TrieNodeProvider for ProofTaskTrieNodeProvider {
     }
 }
 /// Result of a proof calculation, which can be either an account multiproof or a storage proof.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub enum ProofResult {
     /// Account multiproof with statistics
     AccountMultiproof {
@@ -708,11 +796,18 @@ where
     /// 2. Advertises availability
     /// 3. Processes jobs in a loop:
     ///    - Receives job from channel
+    ///    - Drains additional same-account storage proof jobs (batching)
     ///    - Marks worker as busy
-    ///    - Processes the job
+    ///    - Processes the batched jobs as a single proof computation
     ///    - Marks worker as available
     /// 4. Shuts down when channel closes
     ///
+    /// # Batching Strategy
+    ///
+    /// When multiple storage proof requests arrive for the same account, they are merged
+    /// into a single proof computation. This reduces redundant trie traversals when state
+    /// updates arrive faster than proof computation can process them.
+    ///
     /// # Panic Safety
     ///
     /// If this function panics, the worker thread terminates but other workers
@@ -732,6 +827,7 @@ where
         // Create provider from factory
         let provider = task_ctx.factory.database_provider_ro()?;
         let proof_tx = ProofTaskTx::new(provider, worker_id);
+        let batch_metrics = StorageWorkerBatchMetrics::new();
 
         trace!(
             target: "trie::proof_task",
@@ -746,20 +842,98 @@ where
         // Initially mark this worker as available.
         available_workers.fetch_add(1, Ordering::Relaxed);
 
+        // Deferred blinded node jobs to process after batched storage proofs.
+        let mut deferred_blinded_nodes: Vec<(B256, Nibbles, Sender<TrieNodeProviderResult>)> =
+            Vec::new();
+
         while let Ok(job) = work_rx.recv() {
             // Mark worker as busy.
             available_workers.fetch_sub(1, Ordering::Relaxed);
 
             match job {
                 StorageWorkerJob::StorageProof { input, proof_result_sender } => {
-                    Self::process_storage_proof(
-                        worker_id,
-                        &proof_tx,
-                        input,
-                        proof_result_sender,
-                        &mut storage_proofs_processed,
-                        &mut cursor_metrics_cache,
+                    // Start batching: group storage proofs by account.
+                    let mut batches: B256Map<BatchedStorageProof> = B256Map::default();
+                    batches.insert(
+                        input.hashed_address,
+                        BatchedStorageProof::new(input, proof_result_sender),
                     );
+                    let mut total_jobs = 1usize;
+
+                    // Drain additional jobs from the queue.
+                    while total_jobs < STORAGE_PROOF_BATCH_LIMIT {
+                        match work_rx.try_recv() {
+                            Ok(StorageWorkerJob::StorageProof {
+                                input: next_input,
+                                proof_result_sender: next_sender,
+                            }) => {
+                                total_jobs += 1;
+                                let addr = next_input.hashed_address;
+                                match batches.entry(addr) {
+                                    alloy_primitives::map::Entry::Occupied(mut entry) => {
+                                        entry.get_mut().merge(next_input, next_sender);
+                                    }
+                                    alloy_primitives::map::Entry::Vacant(entry) => {
+                                        entry.insert(BatchedStorageProof::new(
+                                            next_input,
+                                            next_sender,
+                                        ));
+                                    }
+                                }
+                            }
+                            Ok(StorageWorkerJob::BlindedStorageNode {
+                                account,
+                                path,
+                                result_sender,
+                            }) => {
+                                // Defer blinded node jobs to process after batched proofs.
+                                deferred_blinded_nodes.push((account, path, result_sender));
+                            }
+                            Err(_) => break,
+                        }
+                    }
+
+                    // Process all batched storage proofs.
+                    for (hashed_address, batch) in batches {
+                        let batch_size = batch.senders.len();
+                        batch_metrics.record_batch_size(batch_size);
+
+                        let (merged_input, senders) = batch.into_input(hashed_address);
+
+                        trace!(
+                            target: "trie::proof_task",
+                            worker_id,
+                            ?hashed_address,
+                            batch_size,
+                            prefix_set_len = merged_input.prefix_set.len(),
+                            target_slots_len = merged_input.target_slots.len(),
+                            "Processing batched storage proof"
+                        );
+
+                        Self::process_batched_storage_proof(
+                            worker_id,
+                            &proof_tx,
+                            hashed_address,
+                            merged_input,
+                            senders,
+                            &mut storage_proofs_processed,
+                            &mut cursor_metrics_cache,
+                        );
+                    }
+
+                    // Process any deferred blinded node jobs.
+                    for (account, path, result_sender) in
+                        std::mem::take(&mut deferred_blinded_nodes)
+                    {
+                        Self::process_blinded_node(
+                            worker_id,
+                            &proof_tx,
+                            account,
+                            path,
+                            result_sender,
+                            &mut storage_nodes_processed,
+                        );
+                    }
                 }
 
                 StorageWorkerJob::BlindedStorageNode { account, path, result_sender } => {
@@ -795,82 +969,103 @@ where
         Ok(())
     }
 
-    /// Processes a storage proof request.
-    fn process_storage_proof<Provider>(
+    /// Processes a batched storage proof request and sends results to all waiting receivers.
+    ///
+    /// This computes a single storage proof with merged targets and sends the same result
+    /// to all original requestors, reducing redundant trie traversals.
+    fn process_batched_storage_proof<Provider>(
         worker_id: usize,
         proof_tx: &ProofTaskTx<Provider>,
+        hashed_address: B256,
         input: StorageProofInput,
-        proof_result_sender: ProofResultContext,
+        senders: Vec<ProofResultContext>,
         storage_proofs_processed: &mut u64,
         cursor_metrics_cache: &mut ProofTaskCursorMetricsCache,
     ) where
         Provider: TrieCursorFactory + HashedCursorFactory,
     {
-        let hashed_address = input.hashed_address;
-        let ProofResultContext { sender, sequence_number: seq, state, start_time } =
-            proof_result_sender;
-
         let mut trie_cursor_metrics = TrieCursorMetricsCache::default();
         let mut hashed_cursor_metrics = HashedCursorMetricsCache::default();
 
-        trace!(
-            target: "trie::proof_task",
-            worker_id,
-            hashed_address = ?hashed_address,
-            prefix_set_len = input.prefix_set.len(),
-            target_slots_len = input.target_slots.len(),
-            "Processing storage proof"
-        );
-
         let proof_start = Instant::now();
         let result = proof_tx.compute_storage_proof(
             input,
             &mut trie_cursor_metrics,
             &mut hashed_cursor_metrics,
         );
-
         let proof_elapsed = proof_start.elapsed();
-        *storage_proofs_processed += 1;
-
-        let result_msg = result.map(|storage_proof| ProofResult::StorageProof {
-            hashed_address,
-            proof: storage_proof,
-        });
 
-        if sender
-            .send(ProofResultMessage {
-                sequence_number: seq,
-                result: result_msg,
-                elapsed: start_time.elapsed(),
-                state,
-            })
-            .is_err()
-        {
-            trace!(
-                target: "trie::proof_task",
-                worker_id,
-                hashed_address = ?hashed_address,
-                storage_proofs_processed,
-                "Proof result receiver dropped, discarding result"
-            );
+        // Send the result to all waiting receivers.
+        let num_senders = senders.len();
+        match result {
+            Ok(storage_proof) => {
+                // Success case: clone the proof for each sender.
+                let proof_result =
+                    ProofResult::StorageProof { hashed_address, proof: storage_proof };
+
+                for ProofResultContext { sender, sequence_number, state, start_time } in senders {
+                    *storage_proofs_processed += 1;
+
+                    if sender
+                        .send(ProofResultMessage {
+                            sequence_number,
+                            result: Ok(proof_result.clone()),
+                            elapsed: start_time.elapsed(),
+                            state,
+                        })
+                        .is_err()
+                    {
+                        trace!(
+                            target: "trie::proof_task",
+                            worker_id,
+                            ?hashed_address,
+                            sequence_number,
+                            "Proof result receiver dropped, discarding result"
+                        );
+                    }
+                }
+            }
+            Err(error) => {
+                // Error case: convert to string for cloning, then send to all receivers.
+                let error_msg = error.to_string();
+
+                for ProofResultContext { sender, sequence_number, state, start_time } in senders {
+                    *storage_proofs_processed += 1;
+
+                    if sender
+                        .send(ProofResultMessage {
+                            sequence_number,
+                            result: Err(ParallelStateRootError::Other(error_msg.clone())),
+                            elapsed: start_time.elapsed(),
+                            state,
+                        })
+                        .is_err()
+                    {
+                        trace!(
+                            target: "trie::proof_task",
+                            worker_id,
+                            ?hashed_address,
+                            sequence_number,
+                            "Proof result receiver dropped, discarding result"
+                        );
+                    }
+                }
+            }
         }
 
         trace!(
             target: "trie::proof_task",
             worker_id,
-            hashed_address = ?hashed_address,
+            ?hashed_address,
             proof_time_us = proof_elapsed.as_micros(),
-            total_processed = storage_proofs_processed,
+            num_senders,
             trie_cursor_duration_us = trie_cursor_metrics.total_duration.as_micros(),
             hashed_cursor_duration_us = hashed_cursor_metrics.total_duration.as_micros(),
-            ?trie_cursor_metrics,
-            ?hashed_cursor_metrics,
-            "Storage proof completed"
+            "Batched storage proof completed"
         );
 
         #[cfg(feature = "metrics")]
         {
-            // Accumulate per-proof metrics into the worker's cache
             let per_proof_cache = ProofTaskCursorMetricsCache {
                 account_trie_cursor: TrieCursorMetricsCache::default(),
                 account_hashed_cursor: HashedCursorMetricsCache::default(),
