@@ -41,6 +41,7 @@ use alloy_primitives::{
 use alloy_rlp::{BufMut, Encodable};
 use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
 use dashmap::DashMap;
+#[cfg(feature = "metrics")] use metrics::Histogram;
 use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
 use reth_provider::{DatabaseProviderROFactory, ProviderError, ProviderResult};
 use reth_storage_errors::db::DatabaseError;
@@ -79,6 +80,93 @@ use crate::proof_task_metrics::{
 type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
 type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;
 
+/// Maximum number of storage proof jobs to batch together per account.
+const STORAGE_PROOF_BATCH_LIMIT: usize = 32;
+
+/// Holds batched storage proof jobs for the same account.
+///
+/// When multiple storage proof requests arrive for the same account, they can be merged
+/// into a single proof computation with combined prefix sets and target slots.
+#[derive(Debug)]
+struct BatchedStorageProof {
+    /// The merged prefix set from all batched jobs.
+    prefix_set: PrefixSetMut,
+    /// The merged target slots from all batched jobs.
+    target_slots: B256Set,
+    /// Whether any job requested branch node masks.
+    with_branch_node_masks: bool,
+    /// The `multi_added_removed_keys` from the first job (they should all share the same `Arc`).
+    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
+    /// All senders that need to receive the result.
+    senders: Vec<ProofResultContext>,
+}
+
+impl BatchedStorageProof {
+    /// Creates a new batch from the first storage proof input.
+    fn new(input: StorageProofInput, sender: ProofResultContext) -> Self {
+        // Convert the frozen `PrefixSet` into a mutable `PrefixSetMut` by collecting its keys.
+        let prefix_set = PrefixSetMut::from(input.prefix_set.iter().copied());
+        Self {
+            prefix_set,
+            target_slots: input.target_slots,
+            with_branch_node_masks: input.with_branch_node_masks,
+            multi_added_removed_keys: input.multi_added_removed_keys,
+            senders: vec![sender],
+        }
+    }
+
+    /// Merges another storage proof job into this batch.
+    fn merge(&mut self, input: StorageProofInput, sender: ProofResultContext) {
+        self.prefix_set.extend_keys(input.prefix_set.iter().copied());
+        self.target_slots.extend(input.target_slots);
+        self.with_branch_node_masks |= input.with_branch_node_masks;
+        self.senders.push(sender);
+    }
+
+    /// Converts this batch into a single `StorageProofInput` for computation.
+    fn into_input(self, hashed_address: B256) -> (StorageProofInput, Vec<ProofResultContext>) {
+        let input = StorageProofInput {
+            hashed_address,
+            prefix_set: self.prefix_set.freeze(),
+            target_slots: self.target_slots,
+            with_branch_node_masks: self.with_branch_node_masks,
+            multi_added_removed_keys: self.multi_added_removed_keys,
+        };
+        (input, self.senders)
+    }
+}
+
+/// Metrics for storage worker batching.
+#[derive(Clone, Default)]
+struct StorageWorkerBatchMetrics {
+    /// Histogram of batch sizes (number of jobs merged per computation).
+    #[cfg(feature = "metrics")]
+    batch_size_histogram: Option<Histogram>,
+}
+
+impl StorageWorkerBatchMetrics {
+    #[cfg(feature = "metrics")]
+    fn new() -> Self {
+        Self {
+            batch_size_histogram: Some(metrics::histogram!(
+                "trie.proof_task.storage_worker_batch_size"
+            )),
+        }
+    }
+
+    #[cfg(not(feature = "metrics"))]
+    fn new() -> Self {
+        Self {}
+    }
+
+    fn record_batch_size(&self, _size: usize) {
+        #[cfg(feature = "metrics")]
+        if let Some(h) = &self.batch_size_histogram {
+            h.record(_size as f64);
+        }
+    }
+}
+
 /// A handle that provides type-safe access to proof worker pools.
 ///
 /// The handle stores direct senders to both storage and account worker pools,
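The merge in `BatchedStorageProof` is a union of prefix keys and target slots, an OR of the mask flag, and an append of the waiting sender. A self-contained sketch of the same merge semantics, with `HashSet<u64>` and plain `usize` sender IDs standing in for reth's `PrefixSetMut`, `B256Set`, and `ProofResultContext` (illustrative stand-ins, not the PR's types):

use std::collections::HashSet;

/// Simplified stand-in for `BatchedStorageProof`: one entry per account,
/// accumulating targets from every merged job.
struct Batch {
    target_slots: HashSet<u64>, // stand-in for B256Set
    with_masks: bool,           // ORed across jobs
    senders: Vec<usize>,        // stand-in for ProofResultContext
}

impl Batch {
    fn new(slots: HashSet<u64>, with_masks: bool, sender: usize) -> Self {
        Self { target_slots: slots, with_masks, senders: vec![sender] }
    }

    /// Union the targets, OR the flag, and remember every waiting sender.
    fn merge(&mut self, slots: HashSet<u64>, with_masks: bool, sender: usize) {
        self.target_slots.extend(slots);
        self.with_masks |= with_masks;
        self.senders.push(sender);
    }
}

fn main() {
    let mut batch = Batch::new(HashSet::from([1, 2]), false, 0);
    batch.merge(HashSet::from([2, 3]), true, 1);
    assert_eq!(batch.target_slots.len(), 3); // {1, 2, 3}: duplicate slot 2 collapses
    assert!(batch.with_masks); // any job requesting masks wins
    assert_eq!(batch.senders.len(), 2); // both requestors receive the result
}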
@@ -552,7 +640,7 @@ impl TrieNodeProvider for ProofTaskTrieNodeProvider {
     }
 }
 /// Result of a proof calculation, which can be either an account multiproof or a storage proof.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub enum ProofResult {
     /// Account multiproof with statistics
     AccountMultiproof {
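Deriving `Clone` here is what lets one computed proof be fanned out to every requester in a batch: the worker computes once and each waiter receives its own clone. A minimal sketch of that fan-out over std channels (hypothetical `Proof` type; the real code sends `ProofResultMessage` values):

use std::sync::mpsc::{channel, Sender};

// Stand-in for `ProofResult`: deriving Clone lets one computation serve many waiters.
#[derive(Debug, Clone)]
struct Proof {
    root: u64,
}

// Compute once, then send each batched requester its own clone.
fn fan_out(result: Proof, senders: Vec<Sender<Proof>>) {
    for tx in senders {
        let _ = tx.send(result.clone());
    }
}

fn main() {
    let (tx1, rx1) = channel();
    let (tx2, rx2) = channel();
    fan_out(Proof { root: 42 }, vec![tx1, tx2]);
    assert_eq!(rx1.recv().unwrap().root, 42);
    assert_eq!(rx2.recv().unwrap().root, 42);
}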
@@ -708,11 +796,18 @@ where
     /// 2. Advertises availability
     /// 3. Processes jobs in a loop:
     ///    - Receives job from channel
     ///    - Marks worker as busy
-    ///    - Processes the job
+    ///    - Drains additional queued storage proof jobs, grouping them by account (batching)
+    ///    - Processes each per-account batch as a single proof computation
     ///    - Marks worker as available
     /// 4. Shuts down when channel closes
     ///
+    /// # Batching Strategy
+    ///
+    /// When multiple storage proof requests arrive for the same account, they are merged
+    /// into a single proof computation (one proof per account). This reduces redundant trie
+    /// traversals when state updates arrive faster than proof computation can keep up.
+    ///
     /// # Panic Safety
     ///
     /// If this function panics, the worker thread terminates but other workers
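The worker loop later in this diff realizes the batching strategy with a blocking `recv` for the first job followed by non-blocking `try_recv` drains up to `STORAGE_PROOF_BATCH_LIMIT`. The bare recv-then-drain shape, sketched with std channels (the PR itself uses crossbeam and a richer job enum):

use std::sync::mpsc::Receiver;

const BATCH_LIMIT: usize = 32;

/// Block for one job, then opportunistically drain more without blocking,
/// up to the batch limit. Returns `None` once the channel is closed.
fn next_batch(rx: &Receiver<u64>) -> Option<Vec<u64>> {
    let first = rx.recv().ok()?; // block until work arrives
    let mut batch = vec![first];
    while batch.len() < BATCH_LIMIT {
        match rx.try_recv() {
            Ok(job) => batch.push(job),
            Err(_) => break, // queue momentarily empty (or disconnected): stop draining
        }
    }
    Some(batch)
}

fn main() {
    let (tx, rx) = std::sync::mpsc::channel();
    for job in 0..5u64 {
        tx.send(job).unwrap();
    }
    assert_eq!(next_batch(&rx).unwrap(), vec![0, 1, 2, 3, 4]);
}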
@@ -732,6 +827,7 @@ where
         // Create provider from factory
         let provider = task_ctx.factory.database_provider_ro()?;
         let proof_tx = ProofTaskTx::new(provider, worker_id);
+        let batch_metrics = StorageWorkerBatchMetrics::new();
 
         trace!(
             target: "trie::proof_task",
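Each worker constructs its own `StorageWorkerBatchMetrics` handle; with the `metrics` crate, `histogram!` invocations sharing a key resolve to one series, so per-worker handles are cheap and all record into `trie.proof_task.storage_worker_batch_size`. A minimal sketch of that pattern (assumes a recorder, e.g. a Prometheus exporter, is installed elsewhere; without one, records are no-ops):

use metrics::{histogram, Histogram};

// Each worker holds its own handle; handles with the same key feed one series.
struct WorkerMetrics {
    batch_size: Histogram,
}

impl WorkerMetrics {
    fn new() -> Self {
        Self { batch_size: histogram!("trie.proof_task.storage_worker_batch_size") }
    }

    fn record(&self, batch_size: usize) {
        self.batch_size.record(batch_size as f64);
    }
}

fn main() {
    // Without an installed recorder, `record` is a no-op.
    let per_worker = WorkerMetrics::new();
    per_worker.record(4);
}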
@@ -746,20 +842,98 @@ where
         // Initially mark this worker as available.
         available_workers.fetch_add(1, Ordering::Relaxed);
 
+        // Deferred blinded node jobs to process after batched storage proofs.
+        let mut deferred_blinded_nodes: Vec<(B256, Nibbles, Sender<TrieNodeProviderResult>)> =
+            Vec::new();
+
         while let Ok(job) = work_rx.recv() {
             // Mark worker as busy.
             available_workers.fetch_sub(1, Ordering::Relaxed);
 
             match job {
                 StorageWorkerJob::StorageProof { input, proof_result_sender } => {
-                    Self::process_storage_proof(
-                        worker_id,
-                        &proof_tx,
-                        input,
-                        proof_result_sender,
-                        &mut storage_proofs_processed,
-                        &mut cursor_metrics_cache,
+                    // Start batching: group storage proofs by account.
+                    let mut batches: B256Map<BatchedStorageProof> = B256Map::default();
+                    batches.insert(
+                        input.hashed_address,
+                        BatchedStorageProof::new(input, proof_result_sender),
                     );
+                    let mut total_jobs = 1usize;
+
+                    // Drain additional jobs from the queue, up to the batch limit.
+                    while total_jobs < STORAGE_PROOF_BATCH_LIMIT {
+                        match work_rx.try_recv() {
+                            Ok(StorageWorkerJob::StorageProof {
+                                input: next_input,
+                                proof_result_sender: next_sender,
+                            }) => {
+                                total_jobs += 1;
+                                let addr = next_input.hashed_address;
+                                match batches.entry(addr) {
+                                    alloy_primitives::map::Entry::Occupied(mut entry) => {
+                                        entry.get_mut().merge(next_input, next_sender);
+                                    }
+                                    alloy_primitives::map::Entry::Vacant(entry) => {
+                                        entry.insert(BatchedStorageProof::new(
+                                            next_input,
+                                            next_sender,
+                                        ));
+                                    }
+                                }
+                            }
+                            Ok(StorageWorkerJob::BlindedStorageNode {
+                                account,
+                                path,
+                                result_sender,
+                            }) => {
+                                // Defer blinded node jobs to process after batched proofs.
+                                deferred_blinded_nodes.push((account, path, result_sender));
+                            }
+                            Err(_) => break,
+                        }
+                    }
+
+                    // Process all batched storage proofs.
+                    for (hashed_address, batch) in batches {
+                        let batch_size = batch.senders.len();
+                        batch_metrics.record_batch_size(batch_size);
+
+                        let (merged_input, senders) = batch.into_input(hashed_address);
+
+                        trace!(
+                            target: "trie::proof_task",
+                            worker_id,
+                            ?hashed_address,
+                            batch_size,
+                            prefix_set_len = merged_input.prefix_set.len(),
+                            target_slots_len = merged_input.target_slots.len(),
+                            "Processing batched storage proof"
+                        );
+
+                        Self::process_batched_storage_proof(
+                            worker_id,
+                            &proof_tx,
+                            hashed_address,
+                            merged_input,
+                            senders,
+                            &mut storage_proofs_processed,
+                            &mut cursor_metrics_cache,
+                        );
+                    }
+
+                    // Process any deferred blinded node jobs.
+                    for (account, path, result_sender) in
+                        std::mem::take(&mut deferred_blinded_nodes)
+                    {
+                        Self::process_blinded_node(
+                            worker_id,
+                            &proof_tx,
+                            account,
+                            path,
+                            result_sender,
+                            &mut storage_nodes_processed,
+                        );
+                    }
                 }
 
                 StorageWorkerJob::BlindedStorageNode { account, path, result_sender } => {
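Note the `std::mem::take(&mut deferred_blinded_nodes)` in the hunk above: it moves the accumulated jobs out for iteration while leaving an empty `Vec` behind, so the buffer declared outside the loop is immediately reusable on the next iteration. In isolation:

fn main() {
    let mut deferred: Vec<u32> = vec![10, 20, 30];

    // `take` moves the contents out for iteration and leaves an empty Vec
    // behind, so the buffer is immediately reusable on the next iteration.
    for job in std::mem::take(&mut deferred) {
        println!("processing deferred job {job}");
    }

    assert!(deferred.is_empty());
    deferred.push(40); // ready to collect the next round of deferrals
}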
@@ -795,82 +969,103 @@ where
         Ok(())
     }
 
-    /// Processes a storage proof request.
-    fn process_storage_proof<Provider>(
+    /// Processes a batched storage proof request and sends results to all waiting receivers.
+    ///
+    /// This computes a single storage proof with merged targets and sends the same result
+    /// to all original requestors, reducing redundant trie traversals.
+    fn process_batched_storage_proof<Provider>(
         worker_id: usize,
         proof_tx: &ProofTaskTx<Provider>,
+        hashed_address: B256,
         input: StorageProofInput,
-        proof_result_sender: ProofResultContext,
+        senders: Vec<ProofResultContext>,
         storage_proofs_processed: &mut u64,
         cursor_metrics_cache: &mut ProofTaskCursorMetricsCache,
     ) where
         Provider: TrieCursorFactory + HashedCursorFactory,
     {
-        let hashed_address = input.hashed_address;
-        let ProofResultContext { sender, sequence_number: seq, state, start_time } =
-            proof_result_sender;
-
         let mut trie_cursor_metrics = TrieCursorMetricsCache::default();
         let mut hashed_cursor_metrics = HashedCursorMetricsCache::default();
 
-        trace!(
-            target: "trie::proof_task",
-            worker_id,
-            hashed_address = ?hashed_address,
-            prefix_set_len = input.prefix_set.len(),
-            target_slots_len = input.target_slots.len(),
-            "Processing storage proof"
-        );
-
         let proof_start = Instant::now();
         let result = proof_tx.compute_storage_proof(
             input,
             &mut trie_cursor_metrics,
             &mut hashed_cursor_metrics,
         );
-
         let proof_elapsed = proof_start.elapsed();
-        *storage_proofs_processed += 1;
-
-        let result_msg = result.map(|storage_proof| ProofResult::StorageProof {
-            hashed_address,
-            proof: storage_proof,
-        });
 
-        if sender
-            .send(ProofResultMessage {
-                sequence_number: seq,
-                result: result_msg,
-                elapsed: start_time.elapsed(),
-                state,
-            })
-            .is_err()
-        {
-            trace!(
-                target: "trie::proof_task",
-                worker_id,
-                hashed_address = ?hashed_address,
-                storage_proofs_processed,
-                "Proof result receiver dropped, discarding result"
-            );
+        // Send the result to all waiting receivers.
+        let num_senders = senders.len();
+        match result {
+            Ok(storage_proof) => {
+                // Success case: clone the proof for each sender.
+                let proof_result =
+                    ProofResult::StorageProof { hashed_address, proof: storage_proof };
+
+                for ProofResultContext { sender, sequence_number, state, start_time } in senders {
+                    *storage_proofs_processed += 1;
+
+                    if sender
+                        .send(ProofResultMessage {
+                            sequence_number,
+                            result: Ok(proof_result.clone()),
+                            elapsed: start_time.elapsed(),
+                            state,
+                        })
+                        .is_err()
+                    {
+                        trace!(
+                            target: "trie::proof_task",
+                            worker_id,
+                            ?hashed_address,
+                            sequence_number,
+                            "Proof result receiver dropped, discarding result"
+                        );
+                    }
+                }
+            }
+            Err(error) => {
+                // Error case: convert to string for cloning, then send to all receivers.
+                let error_msg = error.to_string();
+
+                for ProofResultContext { sender, sequence_number, state, start_time } in senders {
+                    *storage_proofs_processed += 1;
+
+                    if sender
+                        .send(ProofResultMessage {
+                            sequence_number,
+                            result: Err(ParallelStateRootError::Other(error_msg.clone())),
+                            elapsed: start_time.elapsed(),
+                            state,
+                        })
+                        .is_err()
+                    {
+                        trace!(
+                            target: "trie::proof_task",
+                            worker_id,
+                            ?hashed_address,
+                            sequence_number,
+                            "Proof result receiver dropped, discarding result"
+                        );
+                    }
+                }
+            }
         }
 
         trace!(
             target: "trie::proof_task",
             worker_id,
-            hashed_address = ?hashed_address,
+            ?hashed_address,
             proof_time_us = proof_elapsed.as_micros(),
-            total_processed = storage_proofs_processed,
+            num_senders,
             trie_cursor_duration_us = trie_cursor_metrics.total_duration.as_micros(),
             hashed_cursor_duration_us = hashed_cursor_metrics.total_duration.as_micros(),
-            ?trie_cursor_metrics,
-            ?hashed_cursor_metrics,
-            "Storage proof completed"
+            "Batched storage proof completed"
         );
 
         #[cfg(feature = "metrics")]
         {
-            // Accumulate per-proof metrics into the worker's cache
             let per_proof_cache = ProofTaskCursorMetricsCache {
                 account_trie_cursor: TrieCursorMetricsCache::default(),
                 account_hashed_cursor: HashedCursorMetricsCache::default(),
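One design note on the error path above: `ParallelStateRootError` does not implement `Clone`, so the worker stringifies the error once and sends each waiter a clone of the `String`. The same workaround in a self-contained sketch (hypothetical `ComputeError` standing in for the real error type):

use std::sync::mpsc::{channel, Sender};

// Stand-in for a non-Clone error such as `ParallelStateRootError`.
#[derive(Debug)]
struct ComputeError(&'static str);

impl std::fmt::Display for ComputeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "compute failed: {}", self.0)
    }
}

// Stringify the error once; `String` is Clone even though the error is not.
fn fan_out_error(error: ComputeError, senders: Vec<Sender<Result<u64, String>>>) {
    let msg = error.to_string();
    for tx in senders {
        let _ = tx.send(Err(msg.clone()));
    }
}

fn main() {
    let (tx1, rx1) = channel();
    let (tx2, rx2) = channel();
    fan_out_error(ComputeError("missing node"), vec![tx1, tx2]);
    assert!(rx1.recv().unwrap().is_err());
    assert!(rx2.recv().unwrap().is_err());
}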