@@ -421,7 +421,7 @@ def prepare_lora_batch(self, model_worker_batch: ModelWorkerBatch):
         assert len(cur_uids) <= self.max_loras_per_batch

         # Load adapters into device memory pool (CPU -> device transfer)
-        self.memory_pool.prepare_lora_batch(
+        has_new_weights = self.memory_pool.prepare_lora_batch(
             cur_uids=cur_uids,
             lora_adapters=self.loras,
         )
@@ -437,20 +437,27 @@ def prepare_lora_batch(self, model_worker_batch: ModelWorkerBatch):
             lora_ranks[weight_indices[i]] = lora.config.r
             scalings[weight_indices[i]] = lora.scaling

-        self.lora_backend.prepare_lora_batch(
+        batch_info = self.lora_backend.prepare_lora_batch(
             model_worker_batch=model_worker_batch,
             weight_indices=weight_indices,
             lora_ranks=lora_ranks,
             scalings=scalings,
         )
+        model_worker_batch.lora_batch_info = batch_info

         # Update LoRA layer buffer references after loading new weights
         # This is necessary because JAX arrays are immutable, and load_lora_weight_to_buffer
         # creates new arrays. We need to update the references in LoRALinear layers.
-        self.update_lora_info()
+        if has_new_weights:
+            self.update_lora_info()

         logger.debug("Prepared LoRA batch: %d unique adapters", len(cur_uids))

+    def set_batch_info(self, batch_info):
+        """Set batch info in backend."""
+        if hasattr(self, "lora_backend"):
+            self.lora_backend.set_batch_info(batch_info)
+
     def get_buffer_id(self, lora_id: str | None) -> int:
         """Get buffer slot ID for a given LoRA adapter ID."""
         return self.memory_pool.get_buffer_id(lora_id)
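
With this change, `prepare_lora_batch` stashes the backend's batch metadata on `model_worker_batch.lora_batch_info`, and the new `set_batch_info` method lets a caller re-install that metadata into the backend later, independent of the weight transfer. Below is a minimal sketch of how a caller might use this plumbing. The `ModelRunner` class and the `LoRABatchInfo` container are hypothetical and do not appear in this diff; only `prepare_lora_batch`, `set_batch_info`, and the `lora_batch_info` attribute come from the change itself.

```python
from dataclasses import dataclass

import jax


@dataclass
class LoRABatchInfo:
    # Illustrative container only; the real backend may return a different structure.
    weight_indices: jax.Array  # buffer slot assigned to each request
    lora_ranks: jax.Array      # rank r for each buffer slot
    scalings: jax.Array        # LoRA scaling factor for each buffer slot


class ModelRunner:
    """Hypothetical caller showing the assumed flow around the new API."""

    def __init__(self, lora_manager):
        self.lora_manager = lora_manager

    def run_batch(self, model_worker_batch):
        # 1. Load adapters and compute per-batch LoRA metadata; the manager
        #    now stores that metadata on the batch as `lora_batch_info`.
        self.lora_manager.prepare_lora_batch(model_worker_batch)

        # 2. Later, re-install the stored metadata into the backend without
        #    repeating the CPU -> device weight transfer.
        self.lora_manager.set_batch_info(model_worker_batch.lora_batch_info)

        # 3. Run the forward pass; LoRA layers consult the backend's batch
        #    info to pick the right adapter slots, ranks, and scalings.
        return self.forward(model_worker_batch)

    def forward(self, model_worker_batch):
        raise NotImplementedError  # model-specific
```

Splitting "compute the batch info" from "install it in the backend" also explains the `has_new_weights` return value: the layer references only need refreshing when the memory pool actually loaded new (immutable) JAX arrays, so `update_lora_info()` can be skipped on the common path where all adapters are already resident.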