Commit 074af9f

add --enable-static-lora and --lora-scaling (#497)
1 parent 3ccb667 commit 074af9f

File tree

12 files changed: +348 −80 lines

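For orientation, a minimal sketch of the server-argument fields the new code reads. ServerArgsSketch is a hypothetical stand-in, not the real ServerArgs class; only the field names enable_lora, enable_static_lora, and lora_scaling come from this diff, and the values are illustrative.

from dataclasses import dataclass

# Hypothetical stand-in for the real ServerArgs; only the three LoRA-related
# fields read by this commit are shown, with illustrative values.
@dataclass
class ServerArgsSketch:
    enable_lora: bool = True          # existing switch checked in tp_worker.py
    enable_static_lora: bool = True   # new: --enable-static-lora
    lora_scaling: float = 2.0         # new: --lora-scaling

server_args = ServerArgsSketch()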

python/sgl_jax/bench_one_batch.py

Lines changed: 5 additions & 1 deletion
@@ -267,7 +267,11 @@ def _run_forward_and_sample(model_runner, batch: ScheduleBatch, token_first_arg:
     )

     model_worker_batch = batch.get_model_worker_batch(
-        [token_first_arg], [bs_needed], [cache_loc_needed], page_size
+        [token_first_arg],
+        [bs_needed],
+        [cache_loc_needed],
+        page_size,
+        False,
     )

     # Prepare attention forward metadata (required by FlashAttention backend)

python/sgl_jax/srt/entrypoints/engine.py

Lines changed: 19 additions & 0 deletions
@@ -19,6 +19,10 @@
 import uvloop
 import zmq
 import zmq.asyncio
+from flax import nnx
+
+from sgl_jax.srt.utils.common_utils import SUPPORTED_LORA_TARGET_MODULES
+from sgl_jax.utils import traverse_and_update

 # ruff: noqa: E402
 # Fix a bug of Python threading

@@ -194,6 +198,21 @@ async def async_generate(
         else:
             return await generator.__anext__()

+    def apply_dummy_lora_ab_buffer(self, target_modules: list | None = None):
+        if target_modules is None or len(target_modules) == 0:
+            logger.warning("No target_modules specified (%s); skipping apply", target_modules)
+            return
+
+        if "all" in target_modules:
+            target_modules = SUPPORTED_LORA_TARGET_MODULES
+
+        logger.info("Applying dummy LoRA buffers to modules: %s", target_modules)
+
+        model_runner = self.scheduler_info["scheduler"].tp_worker.worker.model_runner
+        model_state = nnx.split(model_runner.model)[1]
+        new_state = traverse_and_update(model_state, target_modules)
+        nnx.update(model_runner.model, new_state)
+
     def encode(
         self,
         prompt: str | list[str] | list[dict] | list[list[dict]],
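
A hypothetical usage sketch of the new Engine method. The Engine construction is omitted (its arguments are not shown in this diff), and the specific projection names are illustrative; they would have to be members of SUPPORTED_LORA_TARGET_MODULES.

# Hypothetical usage; `engine` is an already-constructed sgl_jax Engine.
# Passing "all" expands to SUPPORTED_LORA_TARGET_MODULES inside the method.
engine.apply_dummy_lora_ab_buffer(target_modules=["all"])

# Targeting specific projections; these names are illustrative and must
# appear in SUPPORTED_LORA_TARGET_MODULES.
engine.apply_dummy_lora_ab_buffer(target_modules=["q_proj", "v_proj"])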

python/sgl_jax/srt/lora/backend/bgmv_backend.py

Lines changed: 1 addition & 1 deletion
@@ -233,7 +233,7 @@ def prepare_lora_batch(
            scalings_cpu, [0, num_to_pad], mode="constant", constant_values=0.0
        )
        padded_token_lora_indices_cpu = np.pad(
-            token_lora_indices_cpu, [0, num_to_pad], mode="constant", constant_values=-1
+            token_lora_indices_cpu, [0, num_to_pad], mode="constant", constant_values=0
        )
        padded_lora_ranks_cpu = np.pad(
            lora_ranks_cpu, [0, num_to_pad], mode="constant", constant_values=0
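
For context, a minimal runnable sketch of what this padding change does. The array contents and pad length are made up; only the np.pad calls mirror the diff.

import numpy as np

# Four real tokens mapped to adapter buffer slots, padded up to eight.
token_lora_indices_cpu = np.array([0, 1, 1, 0], dtype=np.int32)
num_to_pad = 4

# Old behavior: padded tokens carried the sentinel -1 ("no adapter").
old = np.pad(token_lora_indices_cpu, [0, num_to_pad], mode="constant", constant_values=-1)
# New behavior: padded tokens point at slot 0, which is also the slot the
# static-LoRA path pins every request to.
new = np.pad(token_lora_indices_cpu, [0, num_to_pad], mode="constant", constant_values=0)

print(old)  # [ 0  1  1  0 -1 -1 -1 -1]
print(new)  # [0 1 1 0 0 0 0 0]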

python/sgl_jax/srt/lora/lora_manager.py

Lines changed: 36 additions & 20 deletions
@@ -111,6 +111,7 @@ def __init__(
         self.num_attention_heads = base_hf_config.num_attention_heads
         self.num_kv_heads = getattr(base_hf_config, "num_key_value_heads", self.num_attention_heads)
         self.head_dim = getattr(base_hf_config, "head_dim", None)
+        self.static_lora = server_args.enable_static_lora

         # Get original num_kv_heads and tp_size for replication
         if model_config is not None:

@@ -420,34 +421,49 @@ def prepare_lora_batch(self, model_worker_batch: ModelWorkerBatch):

         assert len(cur_uids) <= self.max_loras_per_batch

-        # Load adapters into device memory pool (CPU -> device transfer)
-        self.memory_pool.prepare_lora_batch(
-            cur_uids=cur_uids,
-            lora_adapters=self.loras,
-        )
-
         weight_indices = [0] * len(model_worker_batch.lora_ids)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch

-        for i, uid in enumerate(model_worker_batch.lora_ids):
-            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
-            if uid is not None and uid in self.loras:
-                lora = self.loras[uid]
-                lora_ranks[weight_indices[i]] = lora.config.r
-                scalings[weight_indices[i]] = lora.scaling
-
-        self.lora_backend.prepare_lora_batch(
-            model_worker_batch=model_worker_batch,
-            weight_indices=weight_indices,
-            lora_ranks=lora_ranks,
-            scalings=scalings,
-        )
+        def prepare_static_lora_batch():
+            self.lora_backend.prepare_lora_batch(
+                model_worker_batch=model_worker_batch,
+                weight_indices=[0] * len(model_worker_batch.lora_ids),
+                lora_ranks=[self.max_lora_rank] * self.max_loras_per_batch,
+                scalings=[self.server_args.lora_scaling] * self.max_loras_per_batch,
+            )
+
+        def prepare_dynamic_lora_batch():
+            # Load adapters into device memory pool (CPU -> device transfer)
+            self.memory_pool.prepare_lora_batch(
+                cur_uids=cur_uids,
+                lora_adapters=self.loras,
+            )
+
+            for i, uid in enumerate(model_worker_batch.lora_ids):
+                weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+                if uid is not None and uid in self.loras:
+                    lora = self.loras[uid]
+                    lora_ranks[weight_indices[i]] = lora.config.r
+                    scalings[weight_indices[i]] = lora.scaling
+
+            self.lora_backend.prepare_lora_batch(
+                model_worker_batch=model_worker_batch,
+                weight_indices=weight_indices,
+                lora_ranks=lora_ranks,
+                scalings=scalings,
+            )
+
+        if self.static_lora:
+            prepare_static_lora_batch()
+        else:
+            prepare_dynamic_lora_batch()

         # Update LoRA layer buffer references after loading new weights
         # This is necessary because JAX arrays are immutable, and load_lora_weight_to_buffer
         # creates new arrays. We need to update the references in LoRALinear layers.
-        self.update_lora_info()
+        if not self.static_lora:
+            self.update_lora_info()

         logger.debug("Prepared LoRA batch: %d unique adapters", len(cur_uids))
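
To make the two paths concrete, here is an illustrative sketch of the keyword arguments each one ends up handing to lora_backend.prepare_lora_batch for a batch of four requests. max_lora_rank=16, max_loras_per_batch=4, and --lora-scaling 2.0 are assumed values, and the dynamic-side numbers are made up.

# Static path (--enable-static-lora): constants only, so shapes and values
# are identical for every batch and no memory-pool lookup happens.
static_kwargs = dict(
    weight_indices=[0, 0, 0, 0],      # every request reads buffer slot 0
    lora_ranks=[16, 16, 16, 16],      # self.max_lora_rank for each slot
    scalings=[2.0, 2.0, 2.0, 2.0],    # server_args.lora_scaling for each slot
)

# Dynamic path: values depend on which adapters are present in the batch.
dynamic_kwargs = dict(
    weight_indices=[0, 1, 1, 0],      # memory_pool.get_buffer_id(uid) per request
    lora_ranks=[8, 16, 0, 0],         # lora.config.r per occupied slot, 0 otherwise
    scalings=[1.0, 0.5, 0, 0],        # lora.scaling per occupied slot, 0 otherwise
)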

python/sgl_jax/srt/managers/schedule_batch.py

Lines changed: 6 additions & 1 deletion
@@ -1100,6 +1100,7 @@ def get_model_worker_batch(
         bs_paddings: list,
         cache_loc_paddings: list,
         page_size: int,
+        enable_static_lora: bool = False,
         skip_padding: bool = False,
     ) -> ModelWorkerBatch:
         if skip_padding:

@@ -1371,7 +1372,11 @@
             extend_seq_lens=(extend_seq_lens if self.forward_mode == ForwardMode.EXTEND else None),
             extend_logprob_start_lens=extend_logprob_start_lens,
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
-            lora_ids=lora_ids,
+            lora_ids=(
+                [req.lora_id for req in self.reqs] + [None] * bs_padding_size
+                if not enable_static_lora
+                else ["0"] * bs_paddings[select_bs_index]
+            ),
             real_bs=real_bs,
             capture_hidden_mode=CaptureHiddenMode.NULL,
             launch_done=self.launch_done,
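
An illustrative view of the lora_ids this produces for three real requests padded to a batch size of four; the adapter ids are made up.

# Dynamic LoRA: real adapter ids from the requests, None for padding rows.
# Here the third request carries no adapter and the fourth row is padding.
lora_ids_dynamic = ["adapter_a", "adapter_b", None, None]

# Static LoRA (--enable-static-lora): every row, real or padded, points at
# the single static slot "0", so the list never changes across batches.
lora_ids_static = ["0", "0", "0", "0"]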

python/sgl_jax/srt/managers/scheduler.py

Lines changed: 2 additions & 0 deletions
@@ -1189,6 +1189,7 @@ def run_batch(self, batch: ScheduleBatch) -> GenerationBatchResult:
             precompile_bs_paddings,
             precompile_cache_loc_paddings,
             self.page_size,
+            self.server_args.enable_static_lora,
         )

         if self.enable_overlap:

@@ -1229,6 +1230,7 @@ def run_batch(self, batch: ScheduleBatch) -> GenerationBatchResult:
             precompile_bs_paddings,
             precompile_cache_loc_paddings,
             self.page_size,
+            self.server_args.enable_static_lora,
             # eagle's model_worker_batch will be modified and re-padded within eagle_worker
             skip_padding=True,
         )

python/sgl_jax/srt/managers/tp_worker.py

Lines changed: 5 additions & 2 deletions
@@ -236,6 +236,7 @@ def precompile_extend(self, future_token_ids_map=None):
             num_tokens,
             ForwardMode.EXTEND,
             self.precompile_cache_loc_paddings[-1],
+            enable_static_lora=self.server_args.enable_static_lora,
         )
         # Prepare LoRA batch if LoRA is enabled
         if self.server_args.enable_lora:

@@ -278,6 +279,7 @@ def precompile_decode(self, future_token_ids_map=None):
             bs,
             ForwardMode.DECODE,
             aligned_cache_loc_size,
+            enable_static_lora=self.server_args.enable_static_lora,
         )
         # Prepare LoRA batch if LoRA is enabled
         if self.server_args.enable_lora:

@@ -341,6 +343,7 @@ def generate_model_worker_batch(
         max_cache_loc_size: int,
         do_penalties: bool = False,
         speculative_algotithm=None,
+        enable_static_lora: bool = False,
     ) -> ModelWorkerBatch:
         valid_input_ids = np.array([1] * bs, dtype=jnp.int32)
         invalid_input_ids = np.array([0] * (num_tokens - bs), dtype=jnp.int32)

@@ -354,7 +357,7 @@

         valid_cache_loc = np.arange(bs)
         invalid_cache_loc = np.array([0] * (invalid_cache_loc_size), dtype=jnp.int32)
-        lora_ids = [0] if bs == 1 else [0] * (bs // 2) + [None] * (bs - bs // 2)
+        lora_ids = ["0"] if bs == 1 else ["0"] * (bs // 2) + [None] * (bs - bs // 2)

         return ModelWorkerBatch(
             bid=1,

@@ -384,7 +387,7 @@
             extend_logprob_start_lens=None,
             capture_hidden_mode=CaptureHiddenMode.NULL,
             spec_algorithm=speculative_algotithm,
-            lora_ids=lora_ids,
+            lora_ids=lora_ids if not enable_static_lora else ["0"] * bs,
         )

     def get_model_runner(self):
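
For the precompile warm-up batches, the ids are now strings so they line up with the string adapter ids used elsewhere. A sketch for bs = 4, using the half-and-half split hard-coded above:

# Default precompile pattern: the first bs // 2 rows exercise adapter
# slot "0", the rest carry no adapter.
lora_ids_default = ["0", "0", None, None]

# With enable_static_lora, every row uses slot "0", matching the lora_ids
# the runtime static path produces.
lora_ids_static = ["0", "0", "0", "0"]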
