
Commit 9839809

set lora
1 parent 3ccb667

File tree: 9 files changed (+319, -359 lines)


gemini.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# Gemini CLI Constraints & Notes
+
+- **Testing:** I cannot execute tests locally. Any required testing will be deferred to the end of the task for the user to execute manually. I will provide the necessary commands to run the tests.

python/sgl_jax/srt/lora/backend/bgmv_backend.py

Lines changed: 4 additions & 1 deletion
@@ -249,7 +249,10 @@ def prepare_lora_batch(
             lora_ranks=jnp.array(padded_lora_ranks_cpu, dtype=jnp.int32),
         )

-        self.batch_info = BatchInfo(batch_info)
+        return batch_info
+
+    def set_batch_info(self, batch_info: LoRABatchInfo):
+        self.batch_info.value = batch_info


     def shrink(
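The change above splits the backend API: prepare_lora_batch now returns the computed batch info instead of writing it into backend state, and a dedicated set_batch_info installs it later. A minimal, self-contained sketch of that split; FakeBatchInfo and FakeBackend are hypothetical stand-ins, not the sgl_jax classes:

```python
from dataclasses import dataclass


@dataclass
class FakeBatchInfo:
    weight_indices: list[int]
    lora_ranks: list[int]


class FakeBackend:
    def __init__(self):
        self._batch_info = None

    def prepare_lora_batch(self, weight_indices, lora_ranks) -> FakeBatchInfo:
        # Pure computation: build the per-batch metadata and return it
        # instead of mutating backend state here.
        return FakeBatchInfo(weight_indices, lora_ranks)

    def set_batch_info(self, batch_info: FakeBatchInfo) -> None:
        # Mutation is isolated in one explicit setter, mirroring
        # `self.batch_info.value = batch_info` in the diff.
        self._batch_info = batch_info


backend = FakeBackend()
info = backend.prepare_lora_batch([0, 1], [8, 16])  # compute once
backend.set_batch_info(info)                        # install when the batch runs
```

Returning the info lets a caller attach it to the batch object and re-install it later without recomputation, which is what the manager and worker changes below rely on.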

python/sgl_jax/srt/lora/lora_manager.py

Lines changed: 10 additions & 3 deletions
@@ -421,7 +421,7 @@ def prepare_lora_batch(self, model_worker_batch: ModelWorkerBatch):
         assert len(cur_uids) <= self.max_loras_per_batch

         # Load adapters into device memory pool (CPU -> device transfer)
-        self.memory_pool.prepare_lora_batch(
+        has_new_weights = self.memory_pool.prepare_lora_batch(
             cur_uids=cur_uids,
             lora_adapters=self.loras,
         )
@@ -437,20 +437,27 @@ def prepare_lora_batch(self, model_worker_batch: ModelWorkerBatch):
             lora_ranks[weight_indices[i]] = lora.config.r
             scalings[weight_indices[i]] = lora.scaling

-        self.lora_backend.prepare_lora_batch(
+        batch_info = self.lora_backend.prepare_lora_batch(
             model_worker_batch=model_worker_batch,
             weight_indices=weight_indices,
             lora_ranks=lora_ranks,
             scalings=scalings,
         )
+        model_worker_batch.lora_batch_info = batch_info

         # Update LoRA layer buffer references after loading new weights
         # This is necessary because JAX arrays are immutable, and load_lora_weight_to_buffer
         # creates new arrays. We need to update the references in LoRALinear layers.
-        self.update_lora_info()
+        if has_new_weights:
+            self.update_lora_info()

         logger.debug("Prepared LoRA batch: %d unique adapters", len(cur_uids))

+    def set_batch_info(self, batch_info):
+        """Set batch info in backend."""
+        if hasattr(self, "lora_backend"):
+            self.lora_backend.set_batch_info(batch_info)
+
     def get_buffer_id(self, lora_id: str | None) -> int:
         """Get buffer slot ID for a given LoRA adapter ID."""
         return self.memory_pool.get_buffer_id(lora_id)
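The `if has_new_weights:` guard follows the comment kept in the diff: JAX arrays are immutable, so loading an adapter creates new arrays, and any layer still holding a reference to the old buffer reads stale data; when nothing new was loaded, the refresh is unnecessary. A small standalone JAX sketch of that staleness (the `buffer` dict and `cached_ref` are illustrative, not sgl_jax objects):

```python
import jax.numpy as jnp

buffer = {"slot0": jnp.zeros((2, 2))}
cached_ref = buffer["slot0"]  # e.g. a layer caching the buffer array

# Functional update returns a brand-new array: the dict entry changes,
# but the cached reference still points at the old zeros.
buffer["slot0"] = buffer["slot0"].at[0, 0].set(1.0)

print(buffer["slot0"][0, 0])  # 1.0 -> new weights are visible via the pool
print(cached_ref[0, 0])       # 0.0 -> stale reference, must be re-bound
```

Re-binding the references (update_lora_info) is only needed after a load, which is exactly what the returned has_new_weights flag signals.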

python/sgl_jax/srt/lora/lora_memory_pool.py

Lines changed: 9 additions & 1 deletion
@@ -364,7 +364,7 @@ def prepare_lora_batch(
         self,
         cur_uids: set[str | None],
         lora_adapters: dict[str | None, LoRAAdapter],
-    ):
+    ) -> bool:
         """
         Prepare LoRA batch by loading adapters into buffer slots.

@@ -374,6 +374,9 @@ def prepare_lora_batch(
             cur_uids: Set of lora_ids needed for current batch
             lora_adapters: Dict mapping lora_id to LoRAAdapter

+        Returns:
+            bool: True if new weights were loaded (requires updating references), False otherwise.
+
         Raises:
             ValueError: If no buffer slots available
         """
@@ -389,6 +392,8 @@ def get_available_buffer_slot() -> int:
                 self.max_loras_per_batch,
             )

+        has_new_weights = False
+
         # Load each adapter that's not already loaded
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
@@ -397,10 +402,13 @@ def get_available_buffer_slot() -> int:
                 self.load_lora_weight_to_buffer(uid, buffer_id, lora_adapter)
                 self.uid_to_buffer_id[uid] = buffer_id
                 self.buffer_id_to_uid[buffer_id] = uid
+                has_new_weights = True
                 logger.info("Loaded LoRA %s into buffer slot %d", uid, buffer_id)
             else:
                 logger.debug("LoRA %s already in buffer slot %d", uid, self.uid_to_buffer_id[uid])

+        return has_new_weights
+
     def load_lora_weight_to_buffer(
         self,
         uid: str | None,
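The pool change is bookkeeping: load only adapters that are not already resident and tell the caller whether anything was loaded. A toy sketch of that contract; ToyPool is a hypothetical stand-in and the actual weight transfer is reduced to a dict update:

```python
class ToyPool:
    def __init__(self, max_slots: int):
        self.max_slots = max_slots
        self.uid_to_slot: dict[str, int] = {}

    def prepare_lora_batch(self, cur_uids: set[str]) -> bool:
        has_new_weights = False
        for uid in cur_uids:
            if uid not in self.uid_to_slot:
                if len(self.uid_to_slot) >= self.max_slots:
                    raise ValueError("no buffer slots available")
                # Stand-in for load_lora_weight_to_buffer(...)
                self.uid_to_slot[uid] = len(self.uid_to_slot)
                has_new_weights = True
        return has_new_weights


pool = ToyPool(max_slots=4)
print(pool.prepare_lora_batch({"adapter_a"}))  # True  -> first load
print(pool.prepare_lora_batch({"adapter_a"}))  # False -> cache hit, nothing loaded
```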

python/sgl_jax/srt/managers/schedule_batch.py

Lines changed: 2 additions & 0 deletions
@@ -1694,6 +1694,8 @@ class ModelWorkerBatch:
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None

+    lora_batch_info: Any | None = None
+
     tree_cache: BasePrefixCache = None

     def padding_model_worker_batch(

python/sgl_jax/srt/managers/tp_worker.py

Lines changed: 8 additions & 2 deletions
@@ -461,8 +461,14 @@ def forward_batch_generation(
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

         # Prepare LoRA batch if LoRA is enabled
-        if self.worker.server_args.enable_lora and self.need_prepare_lora_batch:
-            self.get_model_runner().lora_manager.prepare_lora_batch(model_worker_batch)
+        if self.worker.server_args.enable_lora:
+            if model_worker_batch.lora_batch_info is None:
+                self.get_model_runner().lora_manager.prepare_lora_batch(model_worker_batch)
+
+            if model_worker_batch.lora_batch_info is not None:
+                self.get_model_runner().lora_manager.set_batch_info(
+                    model_worker_batch.lora_batch_info
+                )

         if forward_metadata is None:
             forward_metadata = self.worker.model_runner.attn_backend.get_forward_metadata(
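The worker now prepares LoRA batch info only when the batch does not already carry it, then installs whatever is attached. A self-contained sketch of that control flow; Batch and Manager are illustrative stand-ins, not the sgl_jax worker or manager:

```python
from dataclasses import dataclass, field


@dataclass
class Batch:
    uids: set = field(default_factory=set)
    lora_batch_info: object = None


class Manager:
    def __init__(self):
        self.active_info = None

    def prepare_lora_batch(self, batch: Batch) -> None:
        # Compute once and attach to the batch, as in the lora_manager diff.
        batch.lora_batch_info = {"uids": sorted(batch.uids)}

    def set_batch_info(self, info) -> None:
        self.active_info = info


def forward(batch: Batch, manager: Manager, enable_lora: bool) -> None:
    if enable_lora:
        if batch.lora_batch_info is None:
            manager.prepare_lora_batch(batch)              # first time: compute + attach
        if batch.lora_batch_info is not None:
            manager.set_batch_info(batch.lora_batch_info)  # always install for this run


manager = Manager()
batch = Batch(uids={"adapter_a"})
forward(batch, manager, enable_lora=True)  # prepares and installs
forward(batch, manager, enable_lora=True)  # reuses the attached info, no re-preparation
```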

python/sgl_jax/srt/model_executor/model_runner.py

Lines changed: 8 additions & 4 deletions
@@ -162,7 +162,7 @@ def initialize(self):

     def initialize_jit(self):
         model_def, model_state = nnx.split(self.model)
-        _, model_state_def = jax.tree_util.tree_flatten(model_state)
+        model_state_leaves, model_state_def = jax.tree_util.tree_flatten(model_state)
         sampler_def, sampler_state = nnx.split(self.sampler)
         sampler_state_leaves, sampler_state_def = jax.tree_util.tree_flatten(sampler_state)

@@ -199,13 +199,17 @@ def run_model_wrapper(forward_batch, logits_metadata):
             token_to_kv_pool = self.token_to_kv_pool

             # Re-capture model state to get the latest LoRA weights
-            _, model_state = nnx.split(self.model)
-            model_state_leaves, _ = jax.tree_util.tree_flatten(model_state)
+            if self.server_args.enable_lora:
+                # Re-capture model state to get the latest LoRA weights
+                _, model_state = nnx.split(self.model)
+                current_model_state_leaves, _ = jax.tree_util.tree_flatten(model_state)
+            else:
+                current_model_state_leaves = model_state_leaves

             return jitted_run_model(
                 model_def,
                 model_state_def,
-                model_state_leaves,
+                current_model_state_leaves,
                 forward_batch,
                 token_to_kv_pool,
                 logits_metadata,
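The model_runner change captures the flattened model-state leaves once in initialize_jit and re-flattens inside the wrapper only when LoRA may have replaced arrays in the model state. A small JAX sketch of why the cached leaves stay valid otherwise; the plain dict here stands in for the nnx model state:

```python
import jax
import jax.numpy as jnp

state = {"w": jnp.ones((2,)), "lora_A": jnp.zeros((2, 2))}

# Captured once at initialization, like model_state_leaves in initialize_jit.
cached_leaves, state_def = jax.tree_util.tree_flatten(state)


def current_leaves(state, enable_lora: bool):
    if enable_lora:
        # LoRA may have swapped arrays into `state`; re-flatten to pick them up.
        leaves, _ = jax.tree_util.tree_flatten(state)
        return leaves
    # No LoRA: the cached leaves are still the live arrays, so skip the re-flatten.
    return cached_leaves


rebuilt = jax.tree_util.tree_unflatten(state_def, current_leaves(state, enable_lora=False))
print(jax.tree_util.tree_structure(rebuilt) == state_def)  # True
```

Skipping the nnx.split / tree_flatten pair on every call avoids a host-side traversal of the model state when LoRA is disabled.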
