@@ -111,6 +111,7 @@ def __init__(
         self.num_attention_heads = base_hf_config.num_attention_heads
         self.num_kv_heads = getattr(base_hf_config, "num_key_value_heads", self.num_attention_heads)
         self.head_dim = getattr(base_hf_config, "head_dim", None)
+        self.static_lora = server_args.enable_static_lora

         # Get original num_kv_heads and tp_size for replication
         if model_config is not None:
@@ -420,34 +421,49 @@ def prepare_lora_batch(self, model_worker_batch: ModelWorkerBatch):

         assert len(cur_uids) <= self.max_loras_per_batch

-        # Load adapters into device memory pool (CPU -> device transfer)
-        self.memory_pool.prepare_lora_batch(
-            cur_uids=cur_uids,
-            lora_adapters=self.loras,
-        )
-
         weight_indices = [0] * len(model_worker_batch.lora_ids)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch

-        for i, uid in enumerate(model_worker_batch.lora_ids):
-            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
-            if uid is not None and uid in self.loras:
-                lora = self.loras[uid]
-                lora_ranks[weight_indices[i]] = lora.config.r
-                scalings[weight_indices[i]] = lora.scaling
-
-        self.lora_backend.prepare_lora_batch(
-            model_worker_batch=model_worker_batch,
-            weight_indices=weight_indices,
-            lora_ranks=lora_ranks,
-            scalings=scalings,
-        )
+        def prepare_static_lora_batch():
+            self.lora_backend.prepare_lora_batch(
+                model_worker_batch=model_worker_batch,
+                weight_indices=[0] * len(model_worker_batch.lora_ids),
+                lora_ranks=[self.max_lora_rank] * self.max_loras_per_batch,
+                scalings=[self.server_args.lora_scaling] * self.max_loras_per_batch,
+            )
+
+        def prepare_dynamic_lora_batch():
+            # Load adapters into device memory pool (CPU -> device transfer)
+            self.memory_pool.prepare_lora_batch(
+                cur_uids=cur_uids,
+                lora_adapters=self.loras,
+            )
+
+            for i, uid in enumerate(model_worker_batch.lora_ids):
+                weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+                if uid is not None and uid in self.loras:
+                    lora = self.loras[uid]
+                    lora_ranks[weight_indices[i]] = lora.config.r
+                    scalings[weight_indices[i]] = lora.scaling
+
+            self.lora_backend.prepare_lora_batch(
+                model_worker_batch=model_worker_batch,
+                weight_indices=weight_indices,
+                lora_ranks=lora_ranks,
+                scalings=scalings,
+            )
+
+        if self.static_lora:
+            prepare_static_lora_batch()
+        else:
+            prepare_dynamic_lora_batch()

         # Update LoRA layer buffer references after loading new weights
         # This is necessary because JAX arrays are immutable, and load_lora_weight_to_buffer
         # creates new arrays. We need to update the references in LoRALinear layers.
-        self.update_lora_info()
+        if not self.static_lora:
+            self.update_lora_info()

         logger.debug("Prepared LoRA batch: %d unique adapters", len(cur_uids))

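The dynamic path still ends with `update_lora_info()` for the reason the comment above gives: JAX arrays are immutable, so loading new weights produces new arrays rather than mutating buffers in place. A minimal standalone sketch of that behavior (the `layer_ref` name is purely illustrative, not from this repo):

```python
import jax.numpy as jnp

# "Writing" to a JAX array returns a brand-new array; the original is untouched.
buffer = jnp.zeros((4, 4))
layer_ref = buffer                   # e.g. the array a LoRALinear layer captured earlier
buffer = buffer.at[0, 0].set(1.0)    # functional update -> a new array is bound to `buffer`
assert layer_ref[0, 0] == 0.0        # the captured reference still sees the old weights
```

This is why layer references must be re-pointed after a load, and why the static path, which skips the per-batch pool transfer entirely, can also skip the refresh.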