Commits (36)
7298039
support trace_agg_mode
JiahangXu Oct 9, 2025
fb08c5f
remove breakpoint and fix corner case
JiahangXu Oct 11, 2025
47f5186
Merge branch 'main' into dev/support_mask
JiahangXu Oct 11, 2025
dfb6323
reformatted daemon
JiahangXu Oct 11, 2025
0d2dece
Merge branch 'dev/support_mask' of github.com:microsoft/agent-lightni…
JiahangXu Oct 11, 2025
76fedd6
add fuzzy_startswith to support special_token_tolerance and string_to…
JiahangXu Oct 27, 2025
ee253db
Merge branch 'main' into dev/support_mask
JiahangXu Nov 4, 2025
a29f5f4
refactor to trace_aggregator
JiahangXu Nov 5, 2025
a77075e
Merge branch 'main' into dev/support_mask
JiahangXu Nov 5, 2025
816c8ed
fix typo
JiahangXu Nov 5, 2025
20ff8ba
add logs, fix mask mapping
JiahangXu Nov 5, 2025
d93c2cc
fix typo
JiahangXu Nov 5, 2025
672f037
fix pylint error
JiahangXu Nov 5, 2025
364d539
fix pylint error
JiahangXu Nov 5, 2025
c8a8b93
Update Search-R1 Example to v0.2.x
SiyunZhao Nov 5, 2025
db9d280
delete redundant script
SiyunZhao Nov 5, 2025
5d551ad
add response id error log, convert to gen response
JiahangXu Nov 10, 2025
f60a49b
fix path
SiyunZhao Nov 13, 2025
5e35898
delete redundant parameter
SiyunZhao Nov 14, 2025
77de9e9
Merge branch 'dev/search_r1_v02' into dev/support_mask
JiahangXu Nov 18, 2025
5c3b9c6
stage debug scripts
JiahangXu Nov 26, 2025
7b57eed
Merge branch 'main' into dev/search_r1_v02
JiahangXu Nov 28, 2025
11e8589
update logger
JiahangXu Nov 28, 2025
161961d
stage test scripts
JiahangXu Nov 28, 2025
283eb89
update daemon timeout
JiahangXu Nov 28, 2025
bcbc95f
Merge branch 'dev/search_r1_v02' into dev/support_mask
JiahangXu Nov 29, 2025
0837a9d
update unmerged logs
JiahangXu Nov 29, 2025
b5afe14
update trajectory scripts
JiahangXu Nov 29, 2025
6bba3dc
Merge branch 'main' into dev/support_mask
JiahangXu Dec 2, 2025
f8a47d9
update mismatch logs, update scripts
JiahangXu Dec 2, 2025
67dad08
update logs code
JiahangXu Dec 3, 2025
e1dbc3c
update running scripts
JiahangXu Dec 3, 2025
e5be217
update external store
JiahangXu Dec 3, 2025
d458b26
update external store (fix bug)
JiahangXu Dec 4, 2025
951b168
update trajectory-strict and trajectory-tolerant
JiahangXu Dec 4, 2025
5a5d54f
update training scripts with new args
JiahangXu Dec 4, 2025
1 change: 1 addition & 0 deletions agentlightning/verl/config.yaml
@@ -19,3 +19,4 @@ actor_rollout_ref:
custom_async_server:
path: pkg://agentlightning.verl.async_server
name: PatchedvLLMServer
trace_agg_mode: transition # transition or trajectory
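The new trace_agg_mode key defaults to transition, which keeps the existing behavior of emitting one training sample per (prompt, response) turn. Setting it to trajectory enables the prefix-merging path added to daemon.py below, where consecutive turns sharing a token-level context prefix are collapsed into one sequence with a response mask over model-generated tokens; a toy walkthrough of that merge follows the daemon.py diff.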
131 changes: 102 additions & 29 deletions agentlightning/verl/daemon.py
@@ -139,6 +139,7 @@ def __init__(
llm_proxy: LLMProxy | None = None,
store: LightningStore | None = None,
adapter: TraceTripletAdapter | None = None,
trace_agg_mode: Literal["transition", "trajectory"] = "transition",
):
self.mode = mode

@@ -179,6 +180,7 @@ def __init__(
self.pad_token_id = pad_token_id
self.tokenizer = tokenizer
self.reward_fillna_value = reward_fillna_value
self.trace_agg_mode = trace_agg_mode

# Internal State
self.backend_llm_server_addresses: List[str] = []
@@ -630,49 +632,119 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int,
reward_list: List[float] = []
data_id_list: List[str] = []
rollout_id_list: List[str] = []
turn_index_list: List[int] = []
turn_index_list: List[int] | List[List[int]] = []
is_drop_list: List[bool] = []
n_trunc_sample_because_of_response = 0

for rollout_id, sample_info in finished_id_to_sample_info.items():
for turn_index, trace in enumerate(sample_info["trace_list"]):
if self.trace_agg_mode == "transition":
for rollout_id, sample_info in finished_id_to_sample_info.items():
for turn_index, trace in enumerate(sample_info["trace_list"]):

reward_list.append(sample_info["reward"])
prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"]
reward_list.append(sample_info["reward"])
prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"]

# Mark samples with prompts exceeding max_prompt_length to be dropped later
if len(prompt_ids) > max_prompt_length:
prompt_ids = prompt_ids[:max_prompt_length]
is_drop_list.append(True)
else:
is_drop_list.append(False)
# Mark samples with prompts exceeding max_prompt_length to be dropped later
if len(prompt_ids) > max_prompt_length:
prompt_ids = prompt_ids[:max_prompt_length]
is_drop_list.append(True)
else:
is_drop_list.append(False)

# Truncate responses that exceed max_response_length
if len(response_ids) > max_response_length:
response_ids = response_ids[:max_response_length]
n_trunc_sample_because_of_response += 1
# Truncate responses that exceed max_response_length
if len(response_ids) > max_response_length:
response_ids = response_ids[:max_response_length]
n_trunc_sample_because_of_response += 1

# Pad prompts to the left and responses to the right
one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
prompt_ids, max_prompt_length, self.pad_token_id
)
one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
response_ids, max_response_length, self.pad_token_id
)
# Pad prompts to the left and responses to the right
one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
prompt_ids, max_prompt_length, self.pad_token_id
)
one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
response_ids, max_response_length, self.pad_token_id
)

input_ids_list.append(one_input_ids)
input_attention_mask_list.append(one_input_attention_mask)
response_ids_list.append(one_response_ids)
response_attention_mask_list.append(one_response_attention_mask)
data_id_list.append(sample_info["data_id"])
rollout_id_list.append(rollout_id)
turn_index_list.append(turn_index)
input_ids_list.append(one_input_ids)
input_attention_mask_list.append(one_input_attention_mask)
response_ids_list.append(one_response_ids)
response_attention_mask_list.append(one_response_attention_mask)
data_id_list.append(sample_info["data_id"])
rollout_id_list.append(rollout_id)
turn_index_list.append(turn_index)

elif self.trace_agg_mode == "trajectory":
response_mask_list: List[List[int]] = []

for rollout_id, sample_info in finished_id_to_sample_info.items():
merged_trace_idx: List[List[int]] = []
current_merged_trace_idx: List[int] = []
current_context: List[int] = []
for turn_index, trace in enumerate(sample_info["trace_list"]):
if (trace["prompt_ids"] + trace["response_ids"])[:len(current_context)] == current_context:
current_context = trace["prompt_ids"] + trace["response_ids"]
current_merged_trace_idx.append(turn_index)
else:
# assert len(current_merged_trace_idx) > 0
merged_trace_idx.append(current_merged_trace_idx)
current_merged_trace_idx = [turn_index]
current_context = trace["prompt_ids"] + trace["response_ids"]
if current_merged_trace_idx:  # flush the trailing merged group
merged_trace_idx.append(current_merged_trace_idx)

for current_merged_trace_idx in merged_trace_idx:
prompt_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["prompt_ids"]
response_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["response_ids"]
prompt_length = len(prompt_ids)
response_mask = [1] * len(response_ids)
for turn_index in current_merged_trace_idx[1:]:
trace = sample_info["trace_list"][turn_index]
new_prompt_length = len(trace["prompt_ids"]) - len(response_ids) - prompt_length
response_ids += trace["prompt_ids"][-new_prompt_length:]
response_ids += trace["response_ids"]
response_mask += [0] * new_prompt_length
response_mask += [1] * len(trace["response_ids"])

reward_list.append(sample_info["reward"])

# Mark samples with prompts exceeding max_prompt_length to be dropped later
if len(prompt_ids) > max_prompt_length:
prompt_ids = prompt_ids[:max_prompt_length]
is_drop_list.append(True)
else:
is_drop_list.append(False)

# Truncate responses that exceed max_response_length
if len(response_ids) > max_response_length:
response_ids = response_ids[:max_response_length]
n_trunc_sample_because_of_response += 1

# Pad prompts to the left and responses to the right
one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask(
prompt_ids, max_prompt_length, self.pad_token_id
)
one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask(
response_ids, max_response_length, self.pad_token_id
)
one_response_mask, _ = get_right_padded_ids_and_attention_mask(
response_mask, max_response_length, 0
)

input_ids_list.append(one_input_ids)
input_attention_mask_list.append(one_input_attention_mask)
response_ids_list.append(one_response_ids)
response_attention_mask_list.append(one_response_attention_mask)
response_mask_list.append(one_response_mask)
data_id_list.append(sample_info["data_id"])
rollout_id_list.append(rollout_id)
turn_index_list.append(current_merged_trace_idx)
else:
raise ValueError(f"Unknown trace_agg_mode: {self.trace_agg_mode}")

n_transition = len(input_ids_list)
batch_input_ids = torch.LongTensor(input_ids_list).to(device)
input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device)
batch_response_ids = torch.LongTensor(response_ids_list).to(device)
response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device)
response_mask = torch.LongTensor(response_mask_list).to(device) if self.trace_agg_mode == "trajectory" else None

# Concatenate prompts and responses to form the full sequence
batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1)
@@ -700,6 +772,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int,
"position_ids": position_ids,
"is_drop_mask": is_drop_mask,
"token_level_scores": token_level_scores.contiguous(),
**({"response_mask": response_mask} if self.trace_agg_mode == "trajectory" else {}),
},
batch_size=n_transition,
)
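To make the trajectory branch concrete, here is a minimal, self-contained sketch of the prefix-merge and mask construction on toy token ids. The data is hypothetical and this is a standalone rewrite that mirrors the grouping in get_train_data_batch above, not the shipped implementation:

# Toy walkthrough of the "trajectory" aggregation (hypothetical token ids).
trace_list = [
    {"prompt_ids": [1, 2], "response_ids": [3, 4]},        # turn 0
    {"prompt_ids": [1, 2, 3, 4, 5], "response_ids": [6]},  # turn 1: extends turn 0's context
    {"prompt_ids": [9, 9], "response_ids": [7]},           # turn 2: unrelated context
]

# Group consecutive turns whose prompt+response extends the running context
# as an exact token prefix (same check as in get_train_data_batch).
groups, current, context = [], [], []
for turn_index, trace in enumerate(trace_list):
    full = trace["prompt_ids"] + trace["response_ids"]
    if full[: len(context)] == context:
        context, current = full, current + [turn_index]
    else:
        groups.append(current)
        current, context = [turn_index], full
if current:
    groups.append(current)
print(groups)  # [[0, 1], [2]]

# Build the merged sequence and response mask for the first group:
# model-generated tokens get 1, interleaved prompt/tool tokens get 0.
first = trace_list[groups[0][0]]
prompt_ids = first["prompt_ids"]
response_ids = list(first["response_ids"])
response_mask = [1] * len(response_ids)
for turn_index in groups[0][1:]:
    trace = trace_list[turn_index]
    new_prompt = len(trace["prompt_ids"]) - len(response_ids) - len(prompt_ids)
    response_ids += trace["prompt_ids"][-new_prompt:] + trace["response_ids"]
    response_mask += [0] * new_prompt + [1] * len(trace["response_ids"])
print(response_ids)   # [3, 4, 5, 6]
print(response_mask)  # [1, 1, 0, 1]

The zeros in response_mask mark tokens that were spliced in between turns (e.g. tool or environment output) so they can be excluded from the policy-gradient loss, while ones mark model-generated tokens.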
4 changes: 3 additions & 1 deletion agentlightning/verl/trainer.py
@@ -137,7 +137,8 @@ def _train_step(self, batch_dict: dict) -> dict:
# uid is used for algorithm like GRPO, should be aligned to data id
batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"]

batch.batch["response_mask"] = compute_response_mask(batch)
if "response_mask" not in batch.batch:
batch.batch["response_mask"] = compute_response_mask(batch)

# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
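With the guard above, a response_mask supplied by the daemon (trajectory mode) takes precedence and is no longer overwritten; in transition mode the batch carries no mask and compute_response_mask fills one in. A sketch of that fallback, assuming verl's convention that responses occupy the right end of the padded sequence:

# Fallback when the daemon provided no mask (transition mode):
# every non-pad response token counts toward the loss.
response_length = batch.batch["responses"].size(1)
batch.batch["response_mask"] = batch.batch["attention_mask"][:, -response_length:]

In trajectory mode the daemon's mask additionally zeros the environment/prompt tokens spliced between turns, so only model-generated tokens contribute to the loss.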
@@ -310,6 +312,7 @@ def fit(self):
store=self.store,
llm_proxy=self.llm_proxy,
adapter=self.adapter,
trace_agg_mode=self.config.actor_rollout_ref.rollout.trace_agg_mode,
)
self.agent_mode_daemon.start()
