From 7298039ef09a5c467ea25e5f83b0ca9022031ce1 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Thu, 9 Oct 2025 14:51:52 +0000 Subject: [PATCH 01/28] support trace_agg_mode --- agentlightning/verl/config.yaml | 1 + agentlightning/verl/daemon.py | 131 +++++++++++++++++++++++++------- agentlightning/verl/trainer.py | 5 +- 3 files changed, 107 insertions(+), 30 deletions(-) diff --git a/agentlightning/verl/config.yaml b/agentlightning/verl/config.yaml index 82b23ae34..89533132e 100644 --- a/agentlightning/verl/config.yaml +++ b/agentlightning/verl/config.yaml @@ -19,3 +19,4 @@ actor_rollout_ref: custom_async_server: path: pkg://agentlightning.verl.async_server name: PatchedvLLMServer + trace_agg_mode: transition # transition or trajectory diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index c693f9c48..e94cc5d81 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -139,6 +139,7 @@ def __init__( llm_proxy: LLMProxy | None = None, store: LightningStore | None = None, adapter: TraceTripletAdapter | None = None, + trace_agg_mode: Literal["transition", "trajectory"] = "transition", ): self.mode = mode @@ -179,6 +180,7 @@ def __init__( self.pad_token_id = pad_token_id self.tokenizer = tokenizer self.reward_fillna_value = reward_fillna_value + self.trace_agg_mode = trace_agg_mode # Internal State self.backend_llm_server_addresses: List[str] = [] @@ -630,49 +632,119 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, reward_list: List[float] = [] data_id_list: List[str] = [] rollout_id_list: List[str] = [] - turn_index_list: List[int] = [] + turn_index_list: List[int] | List[List[int]] = [] is_drop_list: List[bool] = [] n_trunc_sample_because_of_response = 0 - for rollout_id, sample_info in finished_id_to_sample_info.items(): - for turn_index, trace in enumerate(sample_info["trace_list"]): + if self.trace_agg_mode == "transition": + for rollout_id, sample_info in finished_id_to_sample_info.items(): + for turn_index, trace in enumerate(sample_info["trace_list"]): - reward_list.append(sample_info["reward"]) - prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"] + reward_list.append(sample_info["reward"]) + prompt_ids, response_ids = trace["prompt_ids"], trace["response_ids"] - # Mark samples with prompts exceeding max_prompt_length to be dropped later - if len(prompt_ids) > max_prompt_length: - prompt_ids = prompt_ids[:max_prompt_length] - is_drop_list.append(True) - else: - is_drop_list.append(False) + # Mark samples with prompts exceeding max_prompt_length to be dropped later + if len(prompt_ids) > max_prompt_length: + prompt_ids = prompt_ids[:max_prompt_length] + is_drop_list.append(True) + else: + is_drop_list.append(False) - # Truncate responses that exceed max_response_length - if len(response_ids) > max_response_length: - response_ids = response_ids[:max_response_length] - n_trunc_sample_because_of_response += 1 + # Truncate responses that exceed max_response_length + if len(response_ids) > max_response_length: + response_ids = response_ids[:max_response_length] + n_trunc_sample_because_of_response += 1 - # Pad prompts to the left and responses to the right - one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask( - prompt_ids, max_prompt_length, self.pad_token_id - ) - one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask( - response_ids, max_response_length, self.pad_token_id - ) + # Pad prompts to the left and responses to the right 
+ one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask( + prompt_ids, max_prompt_length, self.pad_token_id + ) + one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask( + response_ids, max_response_length, self.pad_token_id + ) - input_ids_list.append(one_input_ids) - input_attention_mask_list.append(one_input_attention_mask) - response_ids_list.append(one_response_ids) - response_attention_mask_list.append(one_response_attention_mask) - data_id_list.append(sample_info["data_id"]) - rollout_id_list.append(rollout_id) - turn_index_list.append(turn_index) + input_ids_list.append(one_input_ids) + input_attention_mask_list.append(one_input_attention_mask) + response_ids_list.append(one_response_ids) + response_attention_mask_list.append(one_response_attention_mask) + data_id_list.append(sample_info["data_id"]) + rollout_id_list.append(rollout_id) + turn_index_list.append(turn_index) + + elif self.trace_agg_mode == "trajectory": + response_mask_list: List[List[int]] = [] + + for rollout_id, sample_info in finished_id_to_sample_info.items(): + merged_trace_idx: List[List[int]] = [] + current_merged_trace_idx: List[int] = [] + current_context: List[int] = [] + for turn_index, trace in enumerate(sample_info["trace_list"]): + if (trace["prompt_ids"] + trace["response_ids"])[:len(current_context)] == current_context: + current_context = trace["prompt_ids"] + trace["response_ids"] + current_merged_trace_idx.append(turn_index) + else: + # assert len(current_merged_trace_idx) > 0 + merged_trace_idx.append(current_merged_trace_idx) + current_merged_trace_idx = [turn_index] + current_context = trace["prompt_ids"] + trace["response_ids"] + if current_merged_trace_idx not in merged_trace_idx: + merged_trace_idx.append(current_merged_trace_idx) + + for current_merged_trace_idx in merged_trace_idx: + prompt_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["prompt_ids"] + response_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["response_ids"] + prompt_length = len(prompt_ids) + response_mask = [1] * len(response_ids) + for turn_index in current_merged_trace_idx[1:]: + trace = sample_info["trace_list"][turn_index] + new_prompt_length = len(trace["prompt_ids"]) - len(response_ids) - prompt_length + response_ids += trace["prompt_ids"][-new_prompt_length:] + response_ids += trace["response_ids"] + response_mask += [0] * new_prompt_length + response_mask += [1] * len(trace["response_ids"]) + + reward_list.append(sample_info["reward"]) + + # Mark samples with prompts exceeding max_prompt_length to be dropped later + if len(prompt_ids) > max_prompt_length: + prompt_ids = prompt_ids[:max_prompt_length] + is_drop_list.append(True) + else: + is_drop_list.append(False) + + # Truncate responses that exceed max_response_length + if len(response_ids) > max_response_length: + response_ids = response_ids[:max_response_length] + n_trunc_sample_because_of_response += 1 + + # Pad prompts to the left and responses to the right + one_input_ids, one_input_attention_mask = get_left_padded_ids_and_attention_mask( + prompt_ids, max_prompt_length, self.pad_token_id + ) + one_response_ids, one_response_attention_mask = get_right_padded_ids_and_attention_mask( + response_ids, max_response_length, self.pad_token_id + ) + one_response_mask, _ = get_right_padded_ids_and_attention_mask( + response_mask, max_response_length, 0 + ) + + input_ids_list.append(one_input_ids) + input_attention_mask_list.append(one_input_attention_mask) + 
response_ids_list.append(one_response_ids) + response_attention_mask_list.append(one_response_attention_mask) + response_mask_list.append(one_response_mask) + data_id_list.append(sample_info["data_id"]) + rollout_id_list.append(rollout_id) + turn_index_list.append(current_merged_trace_idx) + else: + raise ValueError(f"Unknown trace_agg_mode: {self.trace_agg_mode}") n_transition = len(input_ids_list) batch_input_ids = torch.LongTensor(input_ids_list).to(device) input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device) batch_response_ids = torch.LongTensor(response_ids_list).to(device) response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device) + response_mask = torch.LongTensor(response_mask_list).to(device) if self.trace_agg_mode == "trajectory" else None # Concatenate prompts and responses to form the full sequence batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1) @@ -700,6 +772,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, "position_ids": position_ids, "is_drop_mask": is_drop_mask, "token_level_scores": token_level_scores.contiguous(), + **({"response_mask": response_mask} if self.trace_agg_mode == "trajectory" else {}), }, batch_size=n_transition, ) diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py index 4d1459cf9..3cb724758 100644 --- a/agentlightning/verl/trainer.py +++ b/agentlightning/verl/trainer.py @@ -137,7 +137,9 @@ def _train_step(self, batch_dict: dict) -> dict: # uid is used for algorithm like GRPO, should be aligned to data id batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"] - batch.batch["response_mask"] = compute_response_mask(batch) + breakpoint() + if "response_mask" not in batch.batch: + batch.batch["response_mask"] = compute_response_mask(batch) # compute global_valid tokens batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() @@ -310,6 +312,7 @@ def fit(self): store=self.store, llm_proxy=self.llm_proxy, adapter=self.adapter, + trace_agg_mode=self.config.actor_rollout_ref.rollout.trace_agg_mode, ) self.agent_mode_daemon.start() From fb08c5f5cbc845c852fb8e9d66a0991a90d4b54d Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Sat, 11 Oct 2025 03:57:44 +0000 Subject: [PATCH 02/28] remove breakpoint and fix conner case --- agentlightning/verl/daemon.py | 1 + agentlightning/verl/trainer.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index e94cc5d81..30c2f2ef1 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -715,6 +715,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, # Truncate responses that exceed max_response_length if len(response_ids) > max_response_length: response_ids = response_ids[:max_response_length] + response_mask = response_mask[:max_response_length] n_trunc_sample_because_of_response += 1 # Pad prompts to the left and responses to the right diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py index 3cb724758..297d66656 100644 --- a/agentlightning/verl/trainer.py +++ b/agentlightning/verl/trainer.py @@ -137,7 +137,6 @@ def _train_step(self, batch_dict: dict) -> dict: # uid is used for algorithm like GRPO, should be aligned to data id batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"] - breakpoint() if "response_mask" not in batch.batch: batch.batch["response_mask"] = 
compute_response_mask(batch) From dfb6323cdd386479d62fc20936771004c80c7e37 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Sat, 11 Oct 2025 04:11:03 +0000 Subject: [PATCH 03/28] reformatted daemon --- agentlightning/verl/daemon.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index 30c2f2ef1..609eb3485 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -673,13 +673,13 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, elif self.trace_agg_mode == "trajectory": response_mask_list: List[List[int]] = [] - + for rollout_id, sample_info in finished_id_to_sample_info.items(): merged_trace_idx: List[List[int]] = [] current_merged_trace_idx: List[int] = [] current_context: List[int] = [] for turn_index, trace in enumerate(sample_info["trace_list"]): - if (trace["prompt_ids"] + trace["response_ids"])[:len(current_context)] == current_context: + if (trace["prompt_ids"] + trace["response_ids"])[: len(current_context)] == current_context: current_context = trace["prompt_ids"] + trace["response_ids"] current_merged_trace_idx.append(turn_index) else: @@ -687,7 +687,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, merged_trace_idx.append(current_merged_trace_idx) current_merged_trace_idx = [turn_index] current_context = trace["prompt_ids"] + trace["response_ids"] - if current_merged_trace_idx not in merged_trace_idx: + if current_merged_trace_idx not in merged_trace_idx: merged_trace_idx.append(current_merged_trace_idx) for current_merged_trace_idx in merged_trace_idx: From 76fedd6b7bc0654502adbe5a376f2f021285af5b Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Mon, 27 Oct 2025 05:38:58 +0000 Subject: [PATCH 04/28] add fuzzy_startswith to support special_token_tolerance and string_tolerance --- agentlightning/verl/daemon.py | 85 ++++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index fe455bede..cf2f6352d 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -25,6 +25,89 @@ configure_logger() +from transformers import AutoTokenizer +model_dir = '/mnt/teamdrive/RAG_RL/models/meta-llama/Llama-3.2-3B' +tok = AutoTokenizer.from_pretrained(str(model_dir), local_files_only=True, use_fast=True) +# def _decode(ids, skip_special_tokens=True): +# return tok.decode(ids, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=False) + +def fuzzy_startswith(full_ids, prefix_ids, tokenizer, special_token_tolerance=0, string_tolerance=0): + def _special_token_sequence(ids): + return [id for id in ids if id in tokenizer.all_special_ids] + + def _decode(ids, skip_special_tokens=True): + return tokenizer.decode(ids, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=False) + + if special_token_tolerance < 0 or string_tolerance < 0: + raise ValueError("tolerance must be non-negative") + + # First, handle special tokens + full_special_ids = _special_token_sequence(full_ids) + prefix_special_ids = _special_token_sequence(prefix_ids) + diff_count = sum(1 for a, b in zip(full_special_ids, prefix_special_ids) if a != b) + special_token_tolerance -= diff_count + if special_token_tolerance < 0: + return False + + # Next, handle string content + full_string = _decode(full_ids, skip_special_tokens=True) + prefix_string = _decode(prefix_ids, skip_special_tokens=True) + m = 
len(prefix_string) + n = len(full_string) + + if m == 0: return True # Empty B always matches (distance 0 to empty prefix) + if n == 0: return m <= string_tolerance # B non-empty but A empty: only match if we can delete all of B within tolerance + if string_tolerance == 0: return full_string.startswith(prefix_string) # exact match required + + # use DP to compute edit distance with banded optimization + min_j = max(0, m - string_tolerance) + max_j = min(n, m + string_tolerance) + if min_j > max_j: return False # no possible prefix length + + prev_start = max(0, 0 - string_tolerance) + prev_end = min(n, 0 + string_tolerance) + prev = [j for j in range(prev_start, prev_end + 1)] + + for j_idx, j in enumerate(range(prev_start, prev_end + 1)): + if min_j <= j <= max_j and prev[j_idx] <= string_tolerance: + return True + + for i in range(1, m + 1): + # valid j range for this row + start_j = max(0, i - string_tolerance) + end_j = min(n, i + string_tolerance) + cur_len = end_j - start_j + 1 + cur = [0] * cur_len + + for idx, j in enumerate(range(start_j, end_j + 1)): + del_cost = None + prev_start = max(0, (i - 1) - string_tolerance) + prev_end = min(n, (i - 1) + string_tolerance) + if prev_start <= j <= prev_end: + del_cost = prev[j - prev_start] + 1 + else: + del_cost = abs((i - 1) - j) + 1 # safe over-approximation + + ins_cost = None + if j - 1 >= start_j: + ins_cost = cur[idx - 1] + 1 + else: + ins_cost = abs(i - (j - 1)) + 1 + + sub_cost = None + if prev_start <= (j - 1) <= prev_end: + sub_cost = prev[(j - 1) - prev_start] + (0 if prefix_string[i - 1] == full_string[j - 1] else 1) + else: + sub_cost = abs((i - 1) - (j - 1)) + (0 if prefix_string[i - 1] == full_string[j - 1] else 1) + + cur[idx] = min(del_cost, ins_cost, sub_cost) + + for idx, j in enumerate(range(start_j, end_j + 1)): + if min_j <= j <= max_j and cur[idx] <= string_tolerance: + return True + prev = cur + return False + def get_left_padded_ids_and_attention_mask( ids: List[int], max_length: int, pad_token_id: int @@ -686,7 +769,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, current_merged_trace_idx: List[int] = [] current_context: List[int] = [] for turn_index, trace in enumerate(sample_info["trace_list"]): - if (trace["prompt_ids"] + trace["response_ids"])[: len(current_context)] == current_context: + if fuzzy_startswith(trace["prompt_ids"] + trace["response_ids"], current_context, tok, special_token_tolerance=5): current_context = trace["prompt_ids"] + trace["response_ids"] current_merged_trace_idx.append(turn_index) else: From a29f5f4b875b413600ab04fa924d1a413007ccec Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Wed, 5 Nov 2025 07:00:42 +0000 Subject: [PATCH 05/28] refactor to trace_aggregator --- agentlightning/verl/config.yaml | 5 ++++- agentlightning/verl/daemon.py | 30 ++++++++++++++---------------- agentlightning/verl/trainer.py | 2 +- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/agentlightning/verl/config.yaml b/agentlightning/verl/config.yaml index 89533132e..5c2355d45 100644 --- a/agentlightning/verl/config.yaml +++ b/agentlightning/verl/config.yaml @@ -19,4 +19,7 @@ actor_rollout_ref: custom_async_server: path: pkg://agentlightning.verl.async_server name: PatchedvLLMServer - trace_agg_mode: transition # transition or trajectory + trace_aggregator: + mode: transition # transition or trajectory + special_token_tolerance: 10 # only supported in trajectory mode, suggest to set as n_turns + string_tolerance: 20 # only supported in trajectory mode, suggest to 
set as n_turns * 2 diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index ba261f63b..fa18f774a 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -32,18 +32,10 @@ "get_right_padded_ids_and_attention_mask", ] -from transformers import AutoTokenizer -model_dir = '/mnt/teamdrive/RAG_RL/models/meta-llama/Llama-3.2-3B' -tok = AutoTokenizer.from_pretrained(str(model_dir), local_files_only=True, use_fast=True) -# def _decode(ids, skip_special_tokens=True): -# return tok.decode(ids, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=False) def fuzzy_startswith(full_ids, prefix_ids, tokenizer, special_token_tolerance=0, string_tolerance=0): def _special_token_sequence(ids): return [id for id in ids if id in tokenizer.all_special_ids] - - def _decode(ids, skip_special_tokens=True): - return tokenizer.decode(ids, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=False) if special_token_tolerance < 0 or string_tolerance < 0: raise ValueError("tolerance must be non-negative") @@ -57,8 +49,12 @@ def _decode(ids, skip_special_tokens=True): return False # Next, handle string content - full_string = _decode(full_ids, skip_special_tokens=True) - prefix_string = _decode(prefix_ids, skip_special_tokens=True) + full_string = tokenizer.decode(full_ids, skip_special_tokens=True) + prefix_string = tokenizer.decode(prefix_ids, skip_special_tokens=True) + full_ids = tokenizer.encode(full_string) + prefix_ids = tokenizer.encode(prefix_string) + full_string = tokenizer.decode(full_ids, skip_special_tokens=True) + prefix_string = tokenizer.decode(prefix_ids, skip_special_tokens=True) m = len(prefix_string) n = len(full_string) @@ -229,7 +225,7 @@ def __init__( llm_proxy: LLMProxy | None = None, store: LightningStore | None = None, adapter: TraceToTripletBase | None = None, - trace_agg_mode: Literal["transition", "trajectory"] = "transition", + trace_aggregator: Optional[Dict[str, Any]] = None, ): self.mode = mode self.llm_timeout_seconds = llm_timeout_seconds @@ -270,7 +266,7 @@ def __init__( self.pad_token_id = pad_token_id self.tokenizer = tokenizer self.reward_fillna_value = reward_fillna_value - self.trace_agg_mode = trace_agg_mode + self.trace_aggregator = trace_aggregator # Internal State self.backend_llm_server_addresses: List[str] = [] @@ -774,7 +770,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, is_drop_list: List[bool] = [] n_trunc_sample_because_of_response = 0 - if self.trace_agg_mode == "transition": + if self.trace_aggregator.mode == "transition": for rollout_id, sample_info in finished_id_to_sample_info.items(): for turn_index, trace in enumerate(sample_info["trace_list"]): @@ -809,7 +805,8 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, rollout_id_list.append(rollout_id) turn_index_list.append(turn_index) - elif self.trace_agg_mode == "trajectory": + elif self.trace_aggregator.mode == "trajectory": + breakpoint() response_mask_list: List[List[int]] = [] for rollout_id, sample_info in finished_id_to_sample_info.items(): @@ -817,11 +814,12 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, current_merged_trace_idx: List[int] = [] current_context: List[int] = [] for turn_index, trace in enumerate(sample_info["trace_list"]): - if fuzzy_startswith(trace["prompt_ids"] + trace["response_ids"], current_context, tok, special_token_tolerance=5): + if fuzzy_startswith(trace["prompt_ids"] + trace["response_ids"], 
current_context, self.tokenizer, + special_token_tolerance=self.trace_aggregator.special_token_tolerance, + string_tolerance=self.trace_aggregator.string_tolerance): current_context = trace["prompt_ids"] + trace["response_ids"] current_merged_trace_idx.append(turn_index) else: - # assert len(current_merged_trace_idx) > 0 merged_trace_idx.append(current_merged_trace_idx) current_merged_trace_idx = [turn_index] current_context = trace["prompt_ids"] + trace["response_ids"] diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py index d2f77a196..bc8fd016c 100644 --- a/agentlightning/verl/trainer.py +++ b/agentlightning/verl/trainer.py @@ -421,7 +421,7 @@ def fit(self): store=self.store, llm_proxy=self.llm_proxy, adapter=self.adapter, - trace_agg_mode=self.config.actor_rollout_ref.rollout.trace_agg_mode, + trace_aggregator=self.config.actor_rollout_ref.rollout.trace_aggregator, ) self.agent_mode_daemon.start() From 816c8ed17ae95fd3f72baa350ccdb26f63727967 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Wed, 5 Nov 2025 07:43:26 +0000 Subject: [PATCH 06/28] fix typo --- agentlightning/verl/daemon.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index fa18f774a..77a57e3fb 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -874,14 +874,14 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, rollout_id_list.append(rollout_id) turn_index_list.append(current_merged_trace_idx) else: - raise ValueError(f"Unknown trace_agg_mode: {self.trace_agg_mode}") + raise ValueError(f"Unknown trace_aggregator mode: {self.trace_aggregator.mode}") n_transition = len(input_ids_list) batch_input_ids = torch.LongTensor(input_ids_list).to(device) input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device) batch_response_ids = torch.LongTensor(response_ids_list).to(device) response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device) - response_mask = torch.LongTensor(response_mask_list).to(device) if self.trace_agg_mode == "trajectory" else None + response_mask = torch.LongTensor(response_mask_list).to(device) if self.trace_aggregator.mode == "trajectory" else None # Concatenate prompts and responses to form the full sequence batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1) @@ -909,7 +909,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, "position_ids": position_ids, "is_drop_mask": is_drop_mask, "token_level_scores": token_level_scores.contiguous(), - **({"response_mask": response_mask} if self.trace_agg_mode == "trajectory" else {}), + **({"response_mask": response_mask} if self.trace_aggregator.mode == "trajectory" else {}), }, batch_size=n_transition, ) From 20ff8ba04b72bc10ff9349182350416aeacfd203 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Wed, 5 Nov 2025 09:06:40 +0000 Subject: [PATCH 07/28] add logs, fix mask mapping --- agentlightning/verl/daemon.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index 77a57e3fb..dd274f7ed 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -806,14 +806,17 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, turn_index_list.append(turn_index) elif self.trace_aggregator.mode == "trajectory": - breakpoint() response_mask_list: List[List[int]] = 
[] + unmerged_count = 0 # only for debug for rollout_id, sample_info in finished_id_to_sample_info.items(): merged_trace_idx: List[List[int]] = [] current_merged_trace_idx: List[int] = [] current_context: List[int] = [] + turn_ids = [] # log data, only for debug testing for turn_index, trace in enumerate(sample_info["trace_list"]): + # log data, only for debug testing + turn_ids.append({"nxt_turn":trace["prompt_ids"][:] + trace["response_ids"][:], "cur":current_context[:]}) if fuzzy_startswith(trace["prompt_ids"] + trace["response_ids"], current_context, self.tokenizer, special_token_tolerance=self.trace_aggregator.special_token_tolerance, string_tolerance=self.trace_aggregator.string_tolerance): @@ -826,18 +829,35 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, if current_merged_trace_idx not in merged_trace_idx: merged_trace_idx.append(current_merged_trace_idx) + # log data, only for debug testing + if len(merged_trace_idx) > 1: + # import random + # if random.random() < 0.5: + unmerged_count += 1 + for turn_index, d in enumerate(turn_ids): + with open('bad_case_jiahang.log', 'w') as f: + print("-" * 20, file=f) + print(merged_trace_idx, file=f) + print('~' * 20, file=f) + print(turn_index, file=f) + print(d["nxt_turn"], file=f) + print(d["cur"], file=f) + for current_merged_trace_idx in merged_trace_idx: prompt_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["prompt_ids"] - response_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["response_ids"] + accum_response_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["response_ids"] prompt_length = len(prompt_ids) - response_mask = [1] * len(response_ids) + response_mask = [1] * len(accum_response_ids) for turn_index in current_merged_trace_idx[1:]: trace = sample_info["trace_list"][turn_index] - new_prompt_length = len(trace["prompt_ids"]) - len(response_ids) - prompt_length - response_ids += trace["prompt_ids"][-new_prompt_length:] - response_ids += trace["response_ids"] + new_prompt_length = len(trace["prompt_ids"]) - len(accum_response_ids) - prompt_length + accum_response_ids += trace["prompt_ids"][-new_prompt_length:] + accum_response_ids += trace["response_ids"] response_mask += [0] * new_prompt_length response_mask += [1] * len(trace["response_ids"]) + final_sample = sample_info["trace_list"][current_merged_trace_idx[-1]] + response_ids = final_sample["prompt_ids"][prompt_length:] + final_sample["response_ids"] + assert len(response_ids) == len(accum_response_ids) # only for debug testing reward_list.append(sample_info["reward"]) @@ -921,6 +941,8 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, "training/n_rollouts_w_trace": len(finished_id_to_sample_info), "training/n_truncated_triplets": n_trunc_sample_because_of_response, "training/n_triplets": n_transition, + # log data, only for debug testing + **({"training/n_unmerged_turns": unmerged_count} if self.trace_aggregator.mode == "trajectory" else {}), } # Add non-tensor data for advantage calculation and logging From d93c2ccf0238f15b346938a0b08e58e45e858f7f Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Wed, 5 Nov 2025 09:38:20 +0000 Subject: [PATCH 08/28] fix typo --- agentlightning/verl/daemon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index dd274f7ed..5c0c2f23f 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -835,7 +835,7 @@ def get_train_data_batch(self, 
max_prompt_length: int, max_response_length: int, # if random.random() < 0.5: unmerged_count += 1 for turn_index, d in enumerate(turn_ids): - with open('bad_case_jiahang.log', 'w') as f: + with open('bad_case_jiahang.log', 'a+') as f: print("-" * 20, file=f) print(merged_trace_idx, file=f) print('~' * 20, file=f) From 672f03775e7101264ce2892f9a30a96c44f59c18 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Wed, 5 Nov 2025 10:18:48 +0000 Subject: [PATCH 09/28] fix pylint error --- agentlightning/verl/daemon.py | 42 ++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index 5c0c2f23f..2f2435743 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -58,14 +58,18 @@ def _special_token_sequence(ids): m = len(prefix_string) n = len(full_string) - if m == 0: return True # Empty B always matches (distance 0 to empty prefix) - if n == 0: return m <= string_tolerance # B non-empty but A empty: only match if we can delete all of B within tolerance - if string_tolerance == 0: return full_string.startswith(prefix_string) # exact match required + if m == 0: + return True # Empty B always matches (distance 0 to empty prefix) + if n == 0: + return m <= string_tolerance # B non-empty but A empty: only match if we can delete all of B within tolerance + if string_tolerance == 0: + return full_string.startswith(prefix_string) # exact match required # use DP to compute edit distance with banded optimization min_j = max(0, m - string_tolerance) max_j = min(n, m + string_tolerance) - if min_j > max_j: return False # no possible prefix length + if min_j > max_j: + return False # no possible prefix length prev_start = max(0, 0 - string_tolerance) prev_end = min(n, 0 + string_tolerance) @@ -807,19 +811,25 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, elif self.trace_aggregator.mode == "trajectory": response_mask_list: List[List[int]] = [] - unmerged_count = 0 # only for debug + unmerged_count = 0 # only for debug for rollout_id, sample_info in finished_id_to_sample_info.items(): merged_trace_idx: List[List[int]] = [] current_merged_trace_idx: List[int] = [] current_context: List[int] = [] - turn_ids = [] # log data, only for debug testing + turn_ids = [] # log data, only for debug testing for turn_index, trace in enumerate(sample_info["trace_list"]): # log data, only for debug testing - turn_ids.append({"nxt_turn":trace["prompt_ids"][:] + trace["response_ids"][:], "cur":current_context[:]}) - if fuzzy_startswith(trace["prompt_ids"] + trace["response_ids"], current_context, self.tokenizer, - special_token_tolerance=self.trace_aggregator.special_token_tolerance, - string_tolerance=self.trace_aggregator.string_tolerance): + turn_ids.append( + {"nxt_turn":trace["prompt_ids"][:] + trace["response_ids"][:], "cur":current_context[:]} + ) + if fuzzy_startswith( + trace["prompt_ids"] + trace["response_ids"], + current_context, + self.tokenizer, + special_token_tolerance=self.trace_aggregator.special_token_tolerance, + string_tolerance=self.trace_aggregator.string_tolerance, + ): current_context = trace["prompt_ids"] + trace["response_ids"] current_merged_trace_idx.append(turn_index) else: @@ -831,14 +841,12 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, # log data, only for debug testing if len(merged_trace_idx) > 1: - # import random - # if random.random() < 0.5: unmerged_count += 1 for turn_index, d in 
enumerate(turn_ids): - with open('bad_case_jiahang.log', 'a+') as f: + with open("bad_case_jiahang.log", "a+") as f: print("-" * 20, file=f) print(merged_trace_idx, file=f) - print('~' * 20, file=f) + print("~" * 20, file=f) print(turn_index, file=f) print(d["nxt_turn"], file=f) print(d["cur"], file=f) @@ -857,7 +865,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, response_mask += [1] * len(trace["response_ids"]) final_sample = sample_info["trace_list"][current_merged_trace_idx[-1]] response_ids = final_sample["prompt_ids"][prompt_length:] + final_sample["response_ids"] - assert len(response_ids) == len(accum_response_ids) # only for debug testing + assert len(response_ids) == len(accum_response_ids) # only for debug testing reward_list.append(sample_info["reward"]) @@ -901,7 +909,9 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, input_attention_mask = torch.LongTensor(input_attention_mask_list).to(device) batch_response_ids = torch.LongTensor(response_ids_list).to(device) response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device) - response_mask = torch.LongTensor(response_mask_list).to(device) if self.trace_aggregator.mode == "trajectory" else None + response_mask = ( + torch.LongTensor(response_mask_list).to(device) if self.trace_aggregator.mode == "trajectory" else None + ) # Concatenate prompts and responses to form the full sequence batch_seq = torch.cat([batch_input_ids, batch_response_ids], dim=-1) From 364d539b6520717d30b4ebbc7462cfd8db19957c Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Wed, 5 Nov 2025 10:21:36 +0000 Subject: [PATCH 10/28] fix pylint error --- agentlightning/verl/daemon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index 2f2435743..dae55fc72 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -821,7 +821,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, for turn_index, trace in enumerate(sample_info["trace_list"]): # log data, only for debug testing turn_ids.append( - {"nxt_turn":trace["prompt_ids"][:] + trace["response_ids"][:], "cur":current_context[:]} + {"nxt_turn": trace["prompt_ids"][:] + trace["response_ids"][:], "cur": current_context[:]} ) if fuzzy_startswith( trace["prompt_ids"] + trace["response_ids"], From c8a8b9327bd09588979d0b66d952896e30da2331 Mon Sep 17 00:00:00 2001 From: SiyunZhao Date: Wed, 5 Nov 2025 10:23:32 +0000 Subject: [PATCH 11/28] Update Search-R1 Example to v0.2.x --- examples/search_r1/README.md | 19 +- examples/search_r1/qa_em.py | 2 +- examples/search_r1/search_r1_agent.py | 163 ++++++++++++------ examples/search_r1/train_search_r1_agent.py | 182 ++++++++++++++++++++ 4 files changed, 300 insertions(+), 66 deletions(-) create mode 100644 examples/search_r1/train_search_r1_agent.py diff --git a/examples/search_r1/README.md b/examples/search_r1/README.md index 27f431782..4313b8b87 100644 --- a/examples/search_r1/README.md +++ b/examples/search_r1/README.md @@ -2,7 +2,7 @@ ## Overview -This example implements **Search R1** within Agent Lightning. It also serves as a demonstration of a **framework-free agent training pipeline**, showing how to run end-to-end RL training without relying on specialized frameworks. **It's tested and compatible with Agent-lightning v0.1.x**. +This example implements **Search R1** within Agent Lightning. 
It also serves as a demonstration of a **framework-free agent training pipeline**, showing how to run end-to-end RL training without relying on specialized frameworks. **It's tested and compatible with Agent-lightning v0.2.x**. The example is designed to run on a single node with 8 GPUs, each having at least 40 GB of memory. @@ -14,7 +14,7 @@ The example is designed to run on a single node with 8 GPUs, each having at leas | `retrieval_launch.sh` | Launches the retrieval service backed by the processed corpus | | `retrieval_server.py` | FastAPI server that powers document retrieval during training | | `search_r1_agent.py` | Agent-Lightning rollout script implementing the Search-R1 workflow | -| `train.sh` | Starts the RL training server that coordinates GRPO optimization | +| `train_search_r1_agent.py` | RL training script that coordinates GRPO optimization | | `qa_em.py` | Exact-match evaluation utilities for validating model predictions | --- @@ -65,23 +65,14 @@ The retrieval server implementation is based on `search_r1/search/retrieval_serv > If you plan to use WandB for experiment tracking, set the environment variable > `WANDB_API_KEY` before starting Ray. -2. **Launch the Agent** - - ```bash - python search_r1_agent.py - ``` - - This script automatically launches **128 agent workers** by default. Each agent follows the Search-R1 workflow, retrieving information from the database and generating answers accordingly. - - -3. **Start the Training Server** +2. **Start the Training Server** In another terminal, run: ```bash - bash train.sh + python train_search_r1_agent.py llama ``` - This script starts the RL training server. + This script starts the RL training. Each agent follows the Search-R1 workflow, retrieving information from the database and generating answers accordingly. --- diff --git a/examples/search_r1/qa_em.py b/examples/search_r1/qa_em.py index 48617605f..dc917dd18 100644 --- a/examples/search_r1/qa_em.py +++ b/examples/search_r1/qa_em.py @@ -75,7 +75,7 @@ def extract_solution(solution_str: str) -> Optional[str]: matches = list(match_iter) # If there are 0 or exactly 1 matches, return None - if len(matches) <= 1: + if len(matches) == 0: return None # If there are 2 or more matches, return the last one diff --git a/examples/search_r1/search_r1_agent.py b/examples/search_r1/search_r1_agent.py index 446df60ad..224589f31 100644 --- a/examples/search_r1/search_r1_agent.py +++ b/examples/search_r1/search_r1_agent.py @@ -1,20 +1,26 @@ # Copyright (c) Microsoft. All rights reserved. +from __future__ import annotations + import os import re -from typing import Any, Dict, List, Optional, Tuple, TypedDict, cast +import shutil +import tempfile +import time +from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, cast + import requests from openai import OpenAI from qa_em import compute_score_em -from agentlightning import LLM, LitAgent, NamedResources, Trainer, configure_logger, reward +import agentlightning as agl -configure_logger() +agl.configure_logger() -# Copied and adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/scripts/data_process/nq_search.py -INSTRUCTION_FORMAT = """Answer the given question. You must conduct reasoning inside and first every time you get new information. After reasoning, if you find you lack some knowledge, you can call a search engine by query and it will return the top searched results between and . You can search as many times as your want. 
If you find no further external knowledge needed, you can directly provide the answer inside and , without detailed illustrations. For example, Beijing . Question: """ +logger = agl.configure_logger(name=__name__) +INSTRUCTION_FORMAT = """Answer the given question. You must conduct reasoning inside and first every time you get new information. After reasoning, if you find you lack some knowledge, you can call a search engine by query and it will return the top searched results between and . You can search as many times as your want. If you find no further external knowledge needed, you can directly provide the answer inside and , without detailed illustrations. For example, Beijing . Question: """ class Document(TypedDict): contents: str @@ -24,8 +30,7 @@ class RetrievalItem(TypedDict): document: Document -@reward -async def eval(prediction: str, ground_truth: List[str]) -> float: +def eval(prediction: str, ground_truth: List[str]) -> float: reward_score = float(compute_score_em(prediction, ground_truth)) print(f"pred: {prediction} | {type(ground_truth)} gold_answer: {ground_truth} | res: {reward_score}") return reward_score @@ -93,75 +98,131 @@ def passages2string(retrieval_result: List[RetrievalItem]) -> str: def call_llm( llm_client: OpenAI, model_name: str, - content: str, + content: str = "", + messages: List[dict] = [], temperature: float = 1.0, max_tokens: int = 500, ) -> str: + if not len(messages): + messages=[{"role": "user", "content": content}] + print(messages) response = llm_client.chat.completions.create( model=model_name, - messages=[{"role": "user", "content": content}], + messages=messages, temperature=temperature, max_tokens=max_tokens, ) return response.choices[0].message.content or "" -class Searchr1Agent(LitAgent[Any]): - async def training_rollout_async( +class SearchR1Agent(agl.LitAgent[Dict[str, Any]]): + + def __init__( self, - task: Any, - resources: NamedResources, - rollout: Any, - temperature: float = 1.0, - ) -> Any: + val_temperature: Optional[float] = 0.0, + max_turns: int = 4, + ) -> None: + super().__init__() + self.val_temperature = val_temperature + self.data_dir = os.environ.get("VERL_SEARCHR1_DATA_DIR", "data") + self.max_turns = max_turns + + def rollout( + self, + task: Dict[str, Any], + resources: agl.NamedResources, + rollout: agl.Rollout, + ) -> float | None: prompt = INSTRUCTION_FORMAT + task["question"] answer_list: List[str] = cast(List[str], task["golden_answers"]) - llm: LLM = cast(LLM, resources.get("main_llm")) + rollout_id = rollout.rollout_id + logger.info(f"[Rollout {rollout_id}] Question: {task['question']}") + logger.info(f"[Rollout {rollout_id}] Ground Truth: {answer_list}") + + start_time = time.time() + llm: agl.LLM = cast(agl.LLM, resources["main_llm"]) client = OpenAI( - base_url=llm.endpoint, + base_url=llm.get_base_url(rollout_id, rollout.attempt.attempt_id), # type: ignore api_key=os.environ.get("OPENAI_API_KEY", "token-abc123"), ) - turn_id = 0 - finished_flag = False - rollout_content: str = "" - - while turn_id < 4 and not finished_flag: - turn_id += 1 - turn_response = call_llm( - client, llm.model, prompt + rollout_content, temperature=temperature, max_tokens=500 - ) - valid_turn_response = postprocess_response(turn_response) - turn_env_feedback = execute_response(valid_turn_response) - if len(turn_env_feedback) == 0: - finished_flag = True - print(f"TURN ID {turn_id} | RESP: {turn_response} | ENV FEEDBACK: {turn_env_feedback}") - rollout_content += turn_response + turn_env_feedback - - if not finished_flag: - 
turn_response = call_llm( - client, llm.model, prompt + rollout_content, temperature=temperature, max_tokens=500 + if rollout.mode == "train": + temperature = llm.sampling_parameters.get("temperature", 1.0) + else: + temperature = ( + self.val_temperature + if self.val_temperature is not None + else 0.0 ) - rollout_content += turn_response - print(f"LAST TURN GENERATE | RESP: {turn_response}") - reward_score = await eval(rollout_content, answer_list) # reward is tracked with the decorator - print( + turn_id = 0 + finished_flag = False + hist_messages: List[Dict[str, Any]] = [{"role": "user", "content": prompt}] + + try: + while turn_id < self.max_turns and not finished_flag: + turn_id += 1 + turn_response = call_llm( + client, llm.model, messages=hist_messages, temperature=temperature, max_tokens=500 + ) + valid_turn_response = postprocess_response(turn_response) + hist_messages.append({"role": "assistant", "content": valid_turn_response}) + turn_env_feedback = execute_response(valid_turn_response) + if len(turn_env_feedback) == 0: + finished_flag = True + else: + hist_messages.append({"role": "user", "content": turn_env_feedback}) + logger.info(f"TURN ID {turn_id} | RESP: {turn_response} | ENV FEEDBACK: {turn_env_feedback}") + + if not finished_flag: + turn_response = call_llm( + client, llm.model, messages=hist_messages, temperature=temperature, max_tokens=500 + ) + hist_messages.append({"role": "assistant", "content": turn_response}) + logger.info(f"LAST TURN GENERATE | RESP: {turn_response}") + + last_turn_response = [msg["content"] for msg in hist_messages if msg["role"] == "assistant"][-1] + except Exception as e: + logger.exception(f"[Rollout {rollout_id}] Error during rollout: {e}") + return None + + end_time_rollout = time.time() + reward_score = eval(last_turn_response, answer_list) + logger.info("[Rollout %s] Reward: %s", rollout_id, reward_score) + end_time_eval = time.time() + + logger.info("[Rollout %s] Time taken for rollout: %.2f seconds", rollout_id, end_time_rollout - start_time) + logger.info( + "[Rollout %s] Time taken for evaluation: %.2f seconds", rollout_id, end_time_eval - end_time_rollout + ) + logger.info( "question: {} answer: {} ground_truth: {} reward: {}".format( - task["question"], rollout_content, answer_list, reward_score + task["question"], last_turn_response, answer_list, reward_score ) ) return reward_score - async def validation_rollout_async( - self, - task: Any, - resources: NamedResources, - rollout: Any, - ) -> Any: - # Use the same resources; set temperature to 0.0 for deterministic validation. 
- return await self.training_rollout_async(task, resources, rollout, temperature=0.0) + +def debug_search_r1_agent(): + searchr1_dev_data_path = os.path.join(os.environ.get("VERL_SEARCHR1_DATA_DIR", "data"), "test.parquet") + if not os.path.exists(searchr1_dev_data_path): + raise FileNotFoundError(f"Search_R1 dev data file {searchr1_dev_data_path} does not exist.") + df = pd.read_parquet(searchr1_dev_data_path).head(10) # type: ignore + df = cast(List[Dict[str, Any]], df.to_dict(orient="records")) # type: ignore + print("Debug data:", df) + + trainer = agl.Trainer( + n_workers=1, + initial_resources={ + "main_llm": agl.LLM( + endpoint=os.environ["OPENAI_API_BASE"], + model="gpt-4.1-nano", + sampling_parameters={"temperature": 0.0}, + ) + }, + ) + trainer.dev(SearchR1Agent(), df) if __name__ == "__main__": - Trainer(n_workers=128).fit(Searchr1Agent(), "http://localhost:9999/") + debug_search_r1_agent() diff --git a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent.py new file mode 100644 index 000000000..303fdd13b --- /dev/null +++ b/examples/search_r1/train_search_r1_agent.py @@ -0,0 +1,182 @@ +# Copyright (c) Microsoft. All rights reserved. + + +from __future__ import annotations + +import argparse +import os +from copy import deepcopy +from datetime import datetime +from typing import Any, Dict, Optional + +import pandas as pd +from v2_agent import SearchR1Agent + +import agentlightning as agl + +RL_TRAINING_CONFIG: Dict[str, Any] = { + "algorithm": { + "adv_estimator": "grpo", + "use_kl_in_reward": False, + }, + "data": { + "train_files": "data/train.parquet", + "val_files": "data/test.parquet", + "train_batch_size": 512, + "max_prompt_length": 6000, + "max_response_length": 4096, + "truncation": "error", + }, + "actor_rollout_ref": { + "rollout": { + "tensor_model_parallel_size": 1, + "n": 5, + "log_prob_micro_batch_size_per_gpu": 4, + "multi_turn": {"format": "hermes"}, + "name": "vllm", + "gpu_memory_utilization": 0.5, + "engine_kwargs": { + "vllm": { + "enable_auto_tool_choice": True, + "tool_call_parser": "hermes", + } + }, + }, + "actor": { + "ppo_mini_batch_size": 128, + "ppo_micro_batch_size_per_gpu": 4, + "optim": {"lr": 1e-6, "lr_warmup_steps_ratio": 0.95}, + "use_kl_loss": True, + "kl_loss_type": "low_var_kl", + "kl_loss_coef": 0.001, + "entropy_coeff": 0, + "clip_ratio_low": 0.2, + "clip_ratio_high": 0.3, + "fsdp_config": { + "param_offload": True, + "optimizer_offload": True, + }, + }, + "ref": { + "log_prob_micro_batch_size_per_gpu": 4, + "fsdp_config": {"param_offload": True}, + }, + "model": { + "path": "Qwen/Qwen2.5-Coder-1.5B-Instruct", + "use_remove_padding": True, + "enable_gradient_checkpointing": True, + }, + }, + "trainer": { + "n_gpus_per_node": 8, + "val_before_train": True, + "critic_warmup": 0, + "logger": ["console", "wandb"], + "project_name": "AgentLightning", + "experiment_name": "searchr1", + "nnodes": 1, + "test_freq": 10, + "save_freq":10, + "total_epochs": 15, + "total_training_steps": 300, + "default_local_dir": "/mnt/teamdrive/RAG_RL/checkpoints/searchr1_checkpoints/test1102" + }, +} + + +def config_train_fast() -> Dict[str, Any]: + """A fast training run for CI testing purposes.""" + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + EXPERIMENT_NAME = f"searchr1_{timestamp}" + PROJECT_NAME = "AgentLightningCI" + + # Simulate writing to $GITHUB_OUTPUT if it’s set + github_output = os.getenv("GITHUB_OUTPUT") + if github_output: + with open(github_output, "a") as f: + f.write(f"project_name={PROJECT_NAME}\n") + 
f.write(f"run_name={EXPERIMENT_NAME}\n") + + print("Set environment variables:") + print(f"PROJECT_NAME={PROJECT_NAME}") + print(f"EXPERIMENT_NAME={EXPERIMENT_NAME}") + + config = deepcopy(RL_TRAINING_CONFIG) + config["actor_rollout_ref"]["rollout"]["gpu_memory_utilization"] = 0.6 + config["actor_rollout_ref"]["model"]["path"] = "Qwen/Qwen2.5-Coder-0.5B-Instruct" + config["data"]["val_files"] = "data/test_dev.parquet" + config["trainer"]["total_epochs"] = 1 + config["trainer"]["total_training_steps"] = 1 + config["trainer"]["experiment_name"] = EXPERIMENT_NAME + config["trainer"]["project_name"] = PROJECT_NAME + config["trainer"]["test_freq"] = 1 + return config + + +def config_train_qwen() -> Dict[str, Any]: + """A configuration for training with Qwen-2.5B.""" + + config = deepcopy(RL_TRAINING_CONFIG) + return config + + +def config_train_llama() -> Dict[str, Any]: + """A configuration for training with LLaMA-3.2-1B-Instruct. + + You will need a `HF_TOKEN` set to run with this config. + """ + + config = deepcopy(RL_TRAINING_CONFIG) + config["actor_rollout_ref"]["rollout"]["multi_turn"]["format"] = "llama3_json" + config["actor_rollout_ref"]["rollout"]["engine_kwargs"]["vllm"]["tool_call_parser"] = "llama3_json" + config["actor_rollout_ref"]["model"]["path"] = "meta-llama/Llama-3.2-3B-Instruct" + return config + + +def train(config: Dict[str, Any], active_agent: Optional[str]) -> None: + + agent = SearchR1Agent() + algorithm = agl.VERL(config) + trainer = agl.Trainer(n_runners=32, algorithm=algorithm, adapter={"agent_match": active_agent}) + print("Adapter agent match acknowledged:", trainer.adapter.agent_match) # type: ignore + + train_data = pd.read_parquet(config["data"]["train_files"]).to_dict(orient="records") # type: ignore + val_data = pd.read_parquet(config["data"]["val_files"]).to_dict(orient="records") # type: ignore + trainer.fit(agent, train_dataset=train_data, val_dataset=val_data) # type: ignore + + +def main() -> None: + """Main function to parse arguments and run training.""" + parser = argparse.ArgumentParser( + description="Train an Search-R1 agent using different model configurations" + ) + + parser.add_argument( + "config", + choices=["fast", "qwen", "llama"], + help="Training configuration: 'fast' (CI testing), 'qwen' (Qwen-2.5-Coder-1.5B), 'llama' (LLaMA-3.2-3B-Instruct)", + ) + + parser.add_argument( + "--active-agent", type=str, help="Override the active agent name (default: auto-generated based on config)" + ) + + args = parser.parse_args() + + # Get the appropriate configuration + config_functions = {"fast": config_train_fast, "qwen": config_train_qwen, "llama": config_train_llama} + + config = config_functions[args.config]() + + # Set active agent - use provided value or default based on config choice + active_agent = args.active_agent + + print(f"Starting training with '{args.config}' configuration...") + print(f"Active agent: {active_agent}") + + train(config, active_agent) + + +if __name__ == "__main__": + main() From db9d28075de9aecf00a50cc31e236fb04dea8a1f Mon Sep 17 00:00:00 2001 From: SiyunZhao Date: Wed, 5 Nov 2025 10:54:38 +0000 Subject: [PATCH 12/28] delete redundant script --- examples/search_r1/train.sh | 56 --------------------- examples/search_r1/train_search_r1_agent.py | 2 +- 2 files changed, 1 insertion(+), 57 deletions(-) delete mode 100755 examples/search_r1/train.sh diff --git a/examples/search_r1/train.sh b/examples/search_r1/train.sh deleted file mode 100755 index 178260d45..000000000 --- a/examples/search_r1/train.sh +++ /dev/null @@ -1,56 
+0,0 @@ -#!/bin/bash - -set -e - -export N_GPUS=8 -export BASE_MODEL=meta-llama/Llama-3.2-3B -export ROLLOUT_TP_SIZE=1 -export DATA_DIR=data -export EXPERIMENT_NAME=searchr1 -export PROJECT_NAME=AgentLightning-searchr1 -echo "Starting training script..." - -python -m agentlightning.verl \ - algorithm.adv_estimator=grpo \ - data.train_files=${DATA_DIR}/train.parquet \ - data.val_files=${DATA_DIR}/test.parquet \ - actor_rollout_ref.rollout.tensor_model_parallel_size=${ROLLOUT_TP_SIZE} \ - trainer.n_gpus_per_node=${N_GPUS} \ - data.train_batch_size=512 \ - actor_rollout_ref.rollout.n=5 \ - actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.rollout.multi_turn.format=hermes \ - actor_rollout_ref.model.path=${BASE_MODEL} \ - data.max_prompt_length=4096 \ - data.max_response_length=4096 \ - data.truncation='error' \ - trainer.val_before_train=True \ - actor_rollout_ref.actor.optim.lr=1e-6 \ - actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.95 \ - actor_rollout_ref.model.use_remove_padding=True \ - actor_rollout_ref.actor.use_kl_loss=true \ - actor_rollout_ref.actor.kl_loss_type=low_var_kl \ - actor_rollout_ref.actor.kl_loss_coef=0.001 \ - actor_rollout_ref.actor.entropy_coeff=0 \ - actor_rollout_ref.actor.clip_ratio_low=0.2 \ - actor_rollout_ref.actor.clip_ratio_high=0.3 \ - actor_rollout_ref.model.enable_gradient_checkpointing=True \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ - actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ - actor_rollout_ref.ref.fsdp_config.param_offload=True \ - algorithm.use_kl_in_reward=False \ - trainer.critic_warmup=0 \ - trainer.logger=['console','wandb'] \ - trainer.default_local_dir=checkpoints/searchr1_checkpoints/$EXPERIMENT_NAME \ - trainer.project_name=${PROJECT_NAME} \ - trainer.experiment_name=${EXPERIMENT_NAME} \ - trainer.nnodes=1 \ - trainer.save_freq=10 \ - trainer.test_freq=20 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=300 diff --git a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent.py index 303fdd13b..e47baf1d2 100644 --- a/examples/search_r1/train_search_r1_agent.py +++ b/examples/search_r1/train_search_r1_agent.py @@ -79,7 +79,7 @@ "save_freq":10, "total_epochs": 15, "total_training_steps": 300, - "default_local_dir": "/mnt/teamdrive/RAG_RL/checkpoints/searchr1_checkpoints/test1102" + "default_local_dir": "checkpoints/searchr1_checkpoints/" }, } From 5d551ada95d3461c517f8e6db40392a7504e14bd Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Mon, 10 Nov 2025 06:08:06 +0000 Subject: [PATCH 13/28] add response id error log, convert to gen response --- agentlightning/verl/daemon.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index dae55fc72..6f95c0c02 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -865,8 +865,13 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, response_mask += [1] * len(trace["response_ids"]) final_sample = sample_info["trace_list"][current_merged_trace_idx[-1]] response_ids = final_sample["prompt_ids"][prompt_length:] + final_sample["response_ids"] - assert len(response_ids) == 
len(accum_response_ids) # only for debug testing + if len(response_ids) != len(accum_response_ids): # only for debug testing + with open("bad_case_jiahang.log", "a+") as f: + print("-" * 10 + "response_ids NUM NOT MATCH" + "-" * 10, file=f) + print(response_ids, file=f) + print(accum_response_ids, file=f) + response_ids = accum_response_ids # convert to the generating response ids, only for debug testing reward_list.append(sample_info["reward"]) # Mark samples with prompts exceeding max_prompt_length to be dropped later From f60a49b971384780cad766b4366d6b2f42302f5f Mon Sep 17 00:00:00 2001 From: SiyunZhao Date: Thu, 13 Nov 2025 07:35:34 +0000 Subject: [PATCH 14/28] fix path --- examples/search_r1/train_search_r1_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent.py index e47baf1d2..8ef239f98 100644 --- a/examples/search_r1/train_search_r1_agent.py +++ b/examples/search_r1/train_search_r1_agent.py @@ -10,7 +10,7 @@ from typing import Any, Dict, Optional import pandas as pd -from v2_agent import SearchR1Agent +from search_r1_agent import SearchR1Agent import agentlightning as agl From 5e35898f866ba34ff43c7e6784dd06d963ecb2d5 Mon Sep 17 00:00:00 2001 From: SiyunZhao Date: Fri, 14 Nov 2025 04:01:02 +0000 Subject: [PATCH 15/28] delete redundant parameter --- examples/search_r1/train_search_r1_agent.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent.py index 8ef239f98..dbf973585 100644 --- a/examples/search_r1/train_search_r1_agent.py +++ b/examples/search_r1/train_search_r1_agent.py @@ -134,12 +134,11 @@ def config_train_llama() -> Dict[str, Any]: return config -def train(config: Dict[str, Any], active_agent: Optional[str]) -> None: +def train(config: Dict[str, Any]) -> None: agent = SearchR1Agent() algorithm = agl.VERL(config) - trainer = agl.Trainer(n_runners=32, algorithm=algorithm, adapter={"agent_match": active_agent}) - print("Adapter agent match acknowledged:", trainer.adapter.agent_match) # type: ignore + trainer = agl.Trainer(n_runners=32, algorithm=algorithm) train_data = pd.read_parquet(config["data"]["train_files"]).to_dict(orient="records") # type: ignore val_data = pd.read_parquet(config["data"]["val_files"]).to_dict(orient="records") # type: ignore @@ -158,10 +157,6 @@ def main() -> None: help="Training configuration: 'fast' (CI testing), 'qwen' (Qwen-2.5-Coder-1.5B), 'llama' (LLaMA-3.2-3B-Instruct)", ) - parser.add_argument( - "--active-agent", type=str, help="Override the active agent name (default: auto-generated based on config)" - ) - args = parser.parse_args() # Get the appropriate configuration @@ -169,13 +164,9 @@ def main() -> None: config = config_functions[args.config]() - # Set active agent - use provided value or default based on config choice - active_agent = args.active_agent - print(f"Starting training with '{args.config}' configuration...") - print(f"Active agent: {active_agent}") - train(config, active_agent) + train(config) if __name__ == "__main__": From 5c3b9c6694b8e461f49cd4f6deee9b9d3a8c9c4b Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Wed, 26 Nov 2025 16:04:01 +0000 Subject: [PATCH 16/28] stage debug scripts --- agentlightning/verl/daemon.py | 2 +- examples/search_r1/train_search_r1_agent.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py 
index 873e23b67..a158e8f4d 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -436,7 +436,7 @@ def set_up_data_and_server(self, data: Dict[str, Any], server_addresses: List[st raise RuntimeError("Internal loop is not running.") future = asyncio.run_coroutine_threadsafe(coro, self._internal_loop) try: - future.result(timeout=60) # Wait for completion with a timeout + future.result(timeout=180) # Wait for completion with a timeout except Exception as e: print(f"Failed to set up data on server: {e}") raise diff --git a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent.py index dbf973585..d03c3d53d 100644 --- a/examples/search_r1/train_search_r1_agent.py +++ b/examples/search_r1/train_search_r1_agent.py @@ -21,7 +21,7 @@ }, "data": { "train_files": "data/train.parquet", - "val_files": "data/test.parquet", + "val_files": "data/agent_test_50select.parquet", "train_batch_size": 512, "max_prompt_length": 6000, "max_response_length": 4096, @@ -62,7 +62,7 @@ "fsdp_config": {"param_offload": True}, }, "model": { - "path": "Qwen/Qwen2.5-Coder-1.5B-Instruct", + "path": "/home/aiscuser/.cache/huggingface/hub/models--meta-llama--Llama-3.2-3B/snapshots/13afe5124825b4f3751f836b40dafda64c1ed062", "use_remove_padding": True, "enable_gradient_checkpointing": True, }, @@ -79,7 +79,7 @@ "save_freq":10, "total_epochs": 15, "total_training_steps": 300, - "default_local_dir": "checkpoints/searchr1_checkpoints/" + "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/" }, } From 11e8589671acddc169cd788adc123ae9892d7ac3 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Fri, 28 Nov 2025 07:38:35 +0000 Subject: [PATCH 17/28] update logger --- examples/search_r1/search_r1_agent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/search_r1/search_r1_agent.py b/examples/search_r1/search_r1_agent.py index 4ae99f8d3..29b9361fe 100644 --- a/examples/search_r1/search_r1_agent.py +++ b/examples/search_r1/search_r1_agent.py @@ -14,9 +14,10 @@ from openai import OpenAI from qa_em import compute_score_em -from agentlightning import LLM, LitAgent, NamedResources, Rollout, Trainer, reward, setup_logging +from agentlightning import LLM, LitAgent, NamedResources, Rollout, Trainer, setup_logging, configure_logger setup_logging() +logger = configure_logger(name=__name__) INSTRUCTION_FORMAT = """Answer the given question. You must conduct reasoning inside and first every time you get new information. After reasoning, if you find you lack some knowledge, you can call a search engine by query and it will return the top searched results between and . You can search as many times as your want. If you find no further external knowledge needed, you can directly provide the answer inside and , without detailed illustrations. For example, Beijing . 
Question: """ From 161961d72bd6f4e15571e8e12c718f080a64956c Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Fri, 28 Nov 2025 10:41:07 +0000 Subject: [PATCH 18/28] stage test scripts --- examples/search_r1/search_r1_agent.py | 1 + examples/search_r1/train_search_r1_agent.py | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/search_r1/search_r1_agent.py b/examples/search_r1/search_r1_agent.py index 29b9361fe..a68441aee 100644 --- a/examples/search_r1/search_r1_agent.py +++ b/examples/search_r1/search_r1_agent.py @@ -19,6 +19,7 @@ setup_logging() logger = configure_logger(name=__name__) +# Copied and adapted from https://github.com/PeterGriffinJin/Search-R1/blob/main/scripts/data_process/nq_search.py INSTRUCTION_FORMAT = """Answer the given question. You must conduct reasoning inside and first every time you get new information. After reasoning, if you find you lack some knowledge, you can call a search engine by query and it will return the top searched results between and . You can search as many times as your want. If you find no further external knowledge needed, you can directly provide the answer inside and , without detailed illustrations. For example, Beijing . Question: """ class Document(TypedDict): diff --git a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent.py index d03c3d53d..8b5a5a2c9 100644 --- a/examples/search_r1/train_search_r1_agent.py +++ b/examples/search_r1/train_search_r1_agent.py @@ -43,7 +43,7 @@ }, }, "actor": { - "ppo_mini_batch_size": 128, + "ppo_mini_batch_size": 256, "ppo_micro_batch_size_per_gpu": 4, "optim": {"lr": 1e-6, "lr_warmup_steps_ratio": 0.95}, "use_kl_loss": True, @@ -62,7 +62,7 @@ "fsdp_config": {"param_offload": True}, }, "model": { - "path": "/home/aiscuser/.cache/huggingface/hub/models--meta-llama--Llama-3.2-3B/snapshots/13afe5124825b4f3751f836b40dafda64c1ed062", + "path": "meta-llama/Llama-3.2-3B-Instruct", "use_remove_padding": True, "enable_gradient_checkpointing": True, }, @@ -72,14 +72,14 @@ "val_before_train": True, "critic_warmup": 0, "logger": ["console", "wandb"], - "project_name": "AgentLightning", - "experiment_name": "searchr1", + "project_name": "AgentLightning-SearchR1", + "experiment_name": "searchr1_minibatch256_runner32", "nnodes": 1, "test_freq": 10, "save_freq":10, "total_epochs": 15, "total_training_steps": 300, - "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/" + "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/searchr1_minibatch256_runner32/" }, } From 283eb89682e367bb5b0682684b74382c8a0b6419 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Fri, 28 Nov 2025 15:26:20 +0000 Subject: [PATCH 19/28] update daemon timeout --- agentlightning/verl/daemon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index d799dc568..c3b4c47bc 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -429,7 +429,7 @@ def set_up_data_and_server(self, data: Dict[str, Any], server_addresses: List[st raise RuntimeError("Internal loop is not running.") future = asyncio.run_coroutine_threadsafe(coro, self._internal_loop) try: - future.result(timeout=180) # Wait for completion with a timeout + future.result(timeout=3600) # Wait for completion with a timeout except Exception as e: print(f"Failed to set up data on server: {e}") raise From 0837a9df62669caba4ff286c671f6ca4feb4bf14 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Sat, 29 Nov 2025 
15:14:43 +0000 Subject: [PATCH 20/28] update unmerged logs --- agentlightning/verl/config.yaml | 1 + agentlightning/verl/daemon.py | 13 +- agentlightning/verl/trainer.py | 5 +- .../search_r1/debug_train_search_r1_agent.py | 177 ++++++++++++++++++ examples/search_r1/search_r1_agent.py | 8 +- examples/search_r1/train_search_r1_agent.py | 6 + 6 files changed, 202 insertions(+), 8 deletions(-) create mode 100644 examples/search_r1/debug_train_search_r1_agent.py diff --git a/agentlightning/verl/config.yaml b/agentlightning/verl/config.yaml index 5c2355d45..6c0163ee1 100644 --- a/agentlightning/verl/config.yaml +++ b/agentlightning/verl/config.yaml @@ -23,3 +23,4 @@ actor_rollout_ref: mode: transition # transition or trajectory special_token_tolerance: 10 # only supported in trajectory mode, suggest to set as n_turns string_tolerance: 20 # only supported in trajectory mode, suggest to set as n_turns * 2 + trajectory_max_length: 8192 # only supported in trajectory mode, suggest to set as n_turns * (max_response_length + max_prompt_length) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index ba8a848a9..4fbf9b470 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -814,7 +814,8 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, elif self.trace_aggregator.mode == "trajectory": response_mask_list: List[List[int]] = [] - unmerged_count = 0 # only for debug + unmerged_count: int = 0 # only for debug + response_per_turn_list: List[int] = [] # only for debug for rollout_id, sample_info in finished_id_to_sample_info.items(): merged_trace_idx: List[List[int]] = [] @@ -826,6 +827,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, turn_ids.append( {"nxt_turn": trace["prompt_ids"][:] + trace["response_ids"][:], "cur": current_context[:]} ) + response_per_turn_list.append(len(trace["response_ids"])) if fuzzy_startswith( trace["prompt_ids"] + trace["response_ids"], current_context, @@ -961,13 +963,18 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, "training/n_truncated_triplets": n_trunc_sample_because_of_response, "training/n_triplets": n_transition, # log data, only for debug testing - **({"training/n_unmerged_turns": unmerged_count} if self.trace_aggregator.mode == "trajectory" else {}), + **({ + "training/n_unmerged_turns": unmerged_count, + "training/avg_response_by_turn": np.mean(response_per_turn_list) if response_per_turn_list else 0, + "training/max_response_by_turn": np.max(response_per_turn_list) if response_per_turn_list else 0, + "training/min_response_by_turn": np.min(response_per_turn_list) if response_per_turn_list else 0, + } if self.trace_aggregator.mode == "trajectory" else {}), } # Add non-tensor data for advantage calculation and logging data_proto.non_tensor_batch["data_id_list"] = np.array(data_id_list) # type: ignore data_proto.non_tensor_batch["rollout_id_list"] = np.array(rollout_id_list) # type: ignore - data_proto.non_tensor_batch["turn_index_list"] = np.array(turn_index_list) # type: ignore + # data_proto.non_tensor_batch["turn_index_list"] = np.array(turn_index_list) # type: ignore return data_proto, data_metrics diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py index d23ad9c75..984356928 100644 --- a/agentlightning/verl/trainer.py +++ b/agentlightning/verl/trainer.py @@ -217,9 +217,12 @@ def _train_step(self, batch_dict: dict) -> dict: gen_batch.non_tensor_batch, 
self.async_rollout_manager.server_addresses ) self.agent_mode_daemon.run_until_all_finished() + with _timer("gen_postprocess", timing_raw): batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch( max_prompt_length=self.config.data.max_prompt_length, - max_response_length=self.config.data.max_response_length, + max_response_length=self.config.actor_rollout_ref.rollout.trace_aggregator.trajectory_max_length \ + if self.config.actor_rollout_ref.rollout.trace_aggregator.mode == "trajectory" else \ + self.config.data.max_response_length, device=gen_batch.batch["fake_ids"].device, ) metrics.update(agent_metrics) diff --git a/examples/search_r1/debug_train_search_r1_agent.py b/examples/search_r1/debug_train_search_r1_agent.py new file mode 100644 index 000000000..a986b1039 --- /dev/null +++ b/examples/search_r1/debug_train_search_r1_agent.py @@ -0,0 +1,177 @@ +# Copyright (c) Microsoft. All rights reserved. + + +from __future__ import annotations + +import argparse +import os +from copy import deepcopy +from datetime import datetime +from typing import Any, Dict, Optional + +import pandas as pd +from search_r1_agent import SearchR1Agent + +import agentlightning as agl + +RL_TRAINING_CONFIG: Dict[str, Any] = { + "algorithm": { + "adv_estimator": "grpo", + "use_kl_in_reward": False, + }, + "data": { + "train_files": "data/train.parquet", + "val_files": "data/test.parquet", + "train_batch_size": 2, + "max_prompt_length": 6000, + "max_response_length": 4096, + "truncation": "error", + }, + "actor_rollout_ref": { + "rollout": { + "tensor_model_parallel_size": 1, + "n": 1, + "log_prob_micro_batch_size_per_gpu": 1, + "multi_turn": {"format": "hermes"}, + "name": "vllm", + "gpu_memory_utilization": 0.5, + "engine_kwargs": { + "vllm": { + "enable_auto_tool_choice": True, + "tool_call_parser": "hermes", + } + }, + "trace_aggregator": { + "mode": "trajectory", + "trajectory_max_length": 34384, + } + }, + "actor": { + "ppo_mini_batch_size": 1, + "ppo_micro_batch_size_per_gpu": 1, + "optim": {"lr": 1e-6, "lr_warmup_steps_ratio": 0.95}, + "use_kl_loss": True, + "kl_loss_type": "low_var_kl", + "kl_loss_coef": 0.001, + "entropy_coeff": 0, + "clip_ratio_low": 0.2, + "clip_ratio_high": 0.3, + "fsdp_config": { + "param_offload": True, + "optimizer_offload": True, + }, + }, + "ref": { + "log_prob_micro_batch_size_per_gpu": 4, + "fsdp_config": {"param_offload": True}, + }, + "model": { + "path": "meta-llama/Llama-3.2-3B-Instruct", + "use_remove_padding": True, + "enable_gradient_checkpointing": True, + }, + }, + "trainer": { + "n_gpus_per_node": 1, + "val_before_train": False, + "critic_warmup": 0, + "logger": ["console", "wandb"], + "project_name": "AgentLightning-SearchR1", + "experiment_name": "searchr1_test", + "nnodes": 1, + "test_freq": 10, + "save_freq":10, + "total_epochs": 15, + "total_training_steps": 300, + "default_local_dir": "./test/" + }, +} + + +def config_train_fast() -> Dict[str, Any]: + """A fast training run for CI testing purposes.""" + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + EXPERIMENT_NAME = f"searchr1_{timestamp}" + PROJECT_NAME = "AgentLightningCI" + + # Simulate writing to $GITHUB_OUTPUT if it’s set + github_output = os.getenv("GITHUB_OUTPUT") + if github_output: + with open(github_output, "a") as f: + f.write(f"project_name={PROJECT_NAME}\n") + f.write(f"run_name={EXPERIMENT_NAME}\n") + + print("Set environment variables:") + print(f"PROJECT_NAME={PROJECT_NAME}") + print(f"EXPERIMENT_NAME={EXPERIMENT_NAME}") + + config = deepcopy(RL_TRAINING_CONFIG) + 
config["actor_rollout_ref"]["rollout"]["gpu_memory_utilization"] = 0.6 + config["actor_rollout_ref"]["model"]["path"] = "Qwen/Qwen2.5-Coder-0.5B-Instruct" + config["data"]["val_files"] = "data/test_dev.parquet" + config["trainer"]["total_epochs"] = 1 + config["trainer"]["total_training_steps"] = 1 + config["trainer"]["experiment_name"] = EXPERIMENT_NAME + config["trainer"]["project_name"] = PROJECT_NAME + config["trainer"]["test_freq"] = 1 + return config + + +def config_train_qwen() -> Dict[str, Any]: + """A configuration for training with Qwen-2.5B.""" + + config = deepcopy(RL_TRAINING_CONFIG) + return config + + +def config_train_llama() -> Dict[str, Any]: + """A configuration for training with LLaMA-3.2-1B-Instruct. + + You will need a `HF_TOKEN` set to run with this config. + """ + + config = deepcopy(RL_TRAINING_CONFIG) + config["actor_rollout_ref"]["rollout"]["multi_turn"]["format"] = "llama3_json" + config["actor_rollout_ref"]["rollout"]["engine_kwargs"]["vllm"]["tool_call_parser"] = "llama3_json" + config["actor_rollout_ref"]["model"]["path"] = "meta-llama/Llama-3.2-3B-Instruct" + return config + + +def train(config: Dict[str, Any]) -> None: + + agent = SearchR1Agent() + algorithm = agl.VERL(config) + trainer = agl.Trainer(n_runners=32, algorithm=algorithm) + + train_data = pd.read_parquet(config["data"]["train_files"]).to_dict(orient="records") # type: ignore + val_data = pd.read_parquet(config["data"]["val_files"]).to_dict(orient="records") # type: ignore + trainer.fit(agent, train_dataset=train_data, val_dataset=val_data) # type: ignore + + +def main() -> None: + """Main function to parse arguments and run training.""" + parser = argparse.ArgumentParser( + description="Train an Search-R1 agent using different model configurations" + ) + + parser.add_argument( + "config", + choices=["fast", "qwen", "llama"], + help="Training configuration: 'fast' (CI testing), 'qwen' (Qwen-2.5-Coder-1.5B), 'llama' (LLaMA-3.2-3B-Instruct)", + ) + + args = parser.parse_args() + + # Get the appropriate configuration + config_functions = {"fast": config_train_fast, "qwen": config_train_qwen, "llama": config_train_llama} + + config = config_functions[args.config]() + + print(f"Starting training with '{args.config}' configuration...") + + train(config) + + +if __name__ == "__main__": + main() diff --git a/examples/search_r1/search_r1_agent.py b/examples/search_r1/search_r1_agent.py index a68441aee..a444e51bb 100644 --- a/examples/search_r1/search_r1_agent.py +++ b/examples/search_r1/search_r1_agent.py @@ -38,10 +38,10 @@ def eval(prediction: str, ground_truth: List[str]) -> float: def postprocess_response(response: str) -> str: """Process responses to stop at search operation or answer operation.""" - if "" in response: - response = response.split("")[0] + "" - elif "" in response: - response = response.split("")[0] + "" + # if "" in response: + # response = response.split("")[0] + "" + # elif "" in response: + # response = response.split("")[0] + "" return response diff --git a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent.py index 8b5a5a2c9..42d4ccdaf 100644 --- a/examples/search_r1/train_search_r1_agent.py +++ b/examples/search_r1/train_search_r1_agent.py @@ -41,6 +41,12 @@ "tool_call_parser": "hermes", } }, + "trace_aggregator": { + "mode": "trajectory", + "special_token_tolerance": 0, + "string_tolerance": 0, + "trajectory_max_length": 34384, + } }, "actor": { "ppo_mini_batch_size": 256, From b5afe14e82dc7add972e47971eac258171d54e40 Mon Sep 17 00:00:00 
2001 From: jiahangxu Date: Sat, 29 Nov 2025 15:19:03 +0000 Subject: [PATCH 21/28] update trajectory scripts --- examples/search_r1/train_search_r1_agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent.py index 42d4ccdaf..bd2e30c15 100644 --- a/examples/search_r1/train_search_r1_agent.py +++ b/examples/search_r1/train_search_r1_agent.py @@ -68,7 +68,7 @@ "fsdp_config": {"param_offload": True}, }, "model": { - "path": "meta-llama/Llama-3.2-3B-Instruct", + "path": "~/Llama-3.2-3B", "use_remove_padding": True, "enable_gradient_checkpointing": True, }, @@ -79,13 +79,13 @@ "critic_warmup": 0, "logger": ["console", "wandb"], "project_name": "AgentLightning-SearchR1", - "experiment_name": "searchr1_minibatch256_runner32", + "experiment_name": "searchr1_minibatch256_runner32_trajectory", "nnodes": 1, "test_freq": 10, "save_freq":10, "total_epochs": 15, "total_training_steps": 300, - "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/searchr1_minibatch256_runner32/" + "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/searchr1_minibatch256_runner32_trajectory/" }, } From f8a47d98224f8397c3ba91b415e5f5ff9ce8c1e2 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Tue, 2 Dec 2025 12:43:41 +0000 Subject: [PATCH 22/28] update mismatch logs, update scripts --- agentlightning/verl/daemon.py | 102 ++++++++-- agentlightning/verl/trainer.py | 14 ++ examples/search_r1/train_search_r1_agent.py | 2 +- .../train_search_r1_agent_transition.py | 176 ++++++++++++++++++ 4 files changed, 275 insertions(+), 19 deletions(-) create mode 100644 examples/search_r1/train_search_r1_agent_transition.py diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index 4fbf9b470..5c343bec9 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -31,6 +31,71 @@ ] +def logged_startswith(full_ids, prefix_ids, tokenizer): + template_mismatch, retoken_mismatch, others_mismatch = False, False, False + if full_ids[:len(prefix_ids)] == prefix_ids: + merge = True + return template_mismatch, retoken_mismatch, others_mismatch, merge + else: + merge = False + + def _special_token_sequence(ids): + return [id for id in ids if id in tokenizer.all_special_ids] + + def _none_special_token_sequence(ids): + return [id for id in ids if id not in tokenizer.all_special_ids] + + # First, handle special tokens + full_special_ids = _special_token_sequence(full_ids) + prefix_special_ids = _special_token_sequence(prefix_ids) + diff_count = sum(1 for a, b in zip(full_special_ids, prefix_special_ids) if a != b) + if diff_count > 0: + template_mismatch = True + + # Next, handle string content + full_content_ids = _none_special_token_sequence(full_ids) + prefix_content_ids = _none_special_token_sequence(prefix_ids) + full_string = tokenizer.decode(full_ids, skip_special_tokens=True) + prefix_string = tokenizer.decode(prefix_ids, skip_special_tokens=True) + if full_content_ids[:len(prefix_content_ids)] != prefix_content_ids and full_string.startswith(prefix_string): + retoken_mismatch = True + elif full_content_ids[:len(prefix_content_ids)] != prefix_content_ids and not full_string.startswith(prefix_string): + others_mismatch = True + elif full_content_ids[:len(prefix_content_ids)] == prefix_content_ids: + # case 1: fully match; case 2: special token mismatch only + # case 1: template_mismatch == False, retoken_mismatch == False, others_mismatch == False, merge == True + # case 
2: template_mismatch == True, retoken_mismatch == False, others_mismatch == False, merge == False + if (not template_mismatch and not retoken_mismatch and not others_mismatch and merge) \ + or (template_mismatch and not retoken_mismatch and not others_mismatch and not merge): + with open("bad_case_jiahang.log", "a+") as f: + print("-" * 20, file=f) + print("full_ids:", file=f) + print(full_ids, file=f) + print("prefix_ids:", file=f) + print(prefix_ids, file=f) + print(f"template_mismatch: {template_mismatch}, retoken_mismatch: {retoken_mismatch}, others_mismatch: {others_mismatch}, merge: {merge}", file=f) + return template_mismatch, retoken_mismatch, others_mismatch, merge + + +# log data, only for debug testing +def log_mismatch_detail(template_mismatch, retoken_mismatch, others_mismatch, full_ids, prefix_ids): + if template_mismatch: + with open("template_mismatch.log", "a+") as f: + print("-" * 20, file=f) + print(full_ids, file=f) + print(prefix_ids, file=f) + if retoken_mismatch: + with open("retoken_mismatch.log", "a+") as f: + print("-" * 20, file=f) + print(full_ids, file=f) + print(prefix_ids, file=f) + if others_mismatch: + with open("others_mismatch.log", "a+") as f: + print("-" * 20, file=f) + print(full_ids, file=f) + print(prefix_ids, file=f) + + def fuzzy_startswith(full_ids, prefix_ids, tokenizer, special_token_tolerance=0, string_tolerance=0): def _special_token_sequence(ids): return [id for id in ids if id in tokenizer.all_special_ids] @@ -815,6 +880,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, elif self.trace_aggregator.mode == "trajectory": response_mask_list: List[List[int]] = [] unmerged_count: int = 0 # only for debug + template_mismatch_count, retoken_mismatch_count, others_mismatch_count = 0, 0, 0 response_per_turn_list: List[int] = [] # only for debug for rollout_id, sample_info in finished_id_to_sample_info.items(): @@ -828,34 +894,31 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, {"nxt_turn": trace["prompt_ids"][:] + trace["response_ids"][:], "cur": current_context[:]} ) response_per_turn_list.append(len(trace["response_ids"])) - if fuzzy_startswith( + template_mismatch, retoken_mismatch, others_mismatch, merged = logged_startswith( trace["prompt_ids"] + trace["response_ids"], current_context, self.tokenizer, - special_token_tolerance=self.trace_aggregator.special_token_tolerance, - string_tolerance=self.trace_aggregator.string_tolerance, - ): + ) + template_mismatch_count += int(template_mismatch) + retoken_mismatch_count += int(retoken_mismatch) + others_mismatch_count += int(others_mismatch) + if merged: current_context = trace["prompt_ids"] + trace["response_ids"] current_merged_trace_idx.append(turn_index) else: + log_mismatch_detail( # log data, only for debug testing + template_mismatch, + retoken_mismatch, + others_mismatch, + trace["prompt_ids"] + trace["response_ids"], + current_context, + ) merged_trace_idx.append(current_merged_trace_idx) current_merged_trace_idx = [turn_index] current_context = trace["prompt_ids"] + trace["response_ids"] if current_merged_trace_idx not in merged_trace_idx: merged_trace_idx.append(current_merged_trace_idx) - # log data, only for debug testing - if len(merged_trace_idx) > 1: - unmerged_count += 1 - for turn_index, d in enumerate(turn_ids): - with open("bad_case_jiahang.log", "a+") as f: - print("-" * 20, file=f) - print(merged_trace_idx, file=f) - print("~" * 20, file=f) - print(turn_index, file=f) - print(d["nxt_turn"], file=f) - 
print(d["cur"], file=f) - for current_merged_trace_idx in merged_trace_idx: prompt_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["prompt_ids"] accum_response_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["response_ids"] @@ -871,8 +934,8 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, final_sample = sample_info["trace_list"][current_merged_trace_idx[-1]] response_ids = final_sample["prompt_ids"][prompt_length:] + final_sample["response_ids"] if len(response_ids) != len(accum_response_ids): # only for debug testing - with open("bad_case_jiahang.log", "a+") as f: - print("-" * 10 + "response_ids NUM NOT MATCH" + "-" * 10, file=f) + with open("response_ids_num_mismatch.log", "a+") as f: + print("-" * 20, file=f) print(response_ids, file=f) print(accum_response_ids, file=f) @@ -968,6 +1031,9 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, "training/avg_response_by_turn": np.mean(response_per_turn_list) if response_per_turn_list else 0, "training/max_response_by_turn": np.max(response_per_turn_list) if response_per_turn_list else 0, "training/min_response_by_turn": np.min(response_per_turn_list) if response_per_turn_list else 0, + "training/template_mismatch_triplets": template_mismatch_count, + "training/retoken_mismatch_triplets": retoken_mismatch_count, + "training/others_mismatch_triplets": others_mismatch_count, } if self.trace_aggregator.mode == "trajectory" else {}), } diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py index 984356928..c8fe75c66 100644 --- a/agentlightning/verl/trainer.py +++ b/agentlightning/verl/trainer.py @@ -156,6 +156,19 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True, suffix: str return metrics +def log_step_for_mismatch_detail(step: int) -> None: + with open("template_mismatch.log", "a+") as f: + print("-" * 10 + f" Step {step}" + "-" * 10, file=f) + with open("retoken_mismatch.log", "a+") as f: + print("-" * 10 + f" Step {step}" + "-" * 10, file=f) + with open("others_mismatch.log", "a+") as f: + print("-" * 10 + f" Step {step}" + "-" * 10, file=f) + with open("response_ids_num_mismatch.log", "a+") as f: + print("-" * 10 + f" Step {step}" + "-" * 10, file=f) + with open("bad_case_jiahang.log", "a+") as f: + print("-" * 10 + f" Step {step}" + "-" * 10, file=f) + + class AgentLightningTrainer(RayPPOTrainer): """ Specialized PPO trainer for agent-based reinforcement learning. 
@@ -454,6 +467,7 @@ def fit(self): for epoch in range(self.config.trainer.total_epochs): for batch_dict in self.train_dataloader: + log_step_for_mismatch_detail(self.global_steps) # log data, only for debug testing metrics = {} timing_raw = {} is_last_step = self.global_steps >= self.total_training_steps diff --git a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent.py index bd2e30c15..c1d946d19 100644 --- a/examples/search_r1/train_search_r1_agent.py +++ b/examples/search_r1/train_search_r1_agent.py @@ -136,7 +136,7 @@ def config_train_llama() -> Dict[str, Any]: config = deepcopy(RL_TRAINING_CONFIG) config["actor_rollout_ref"]["rollout"]["multi_turn"]["format"] = "llama3_json" config["actor_rollout_ref"]["rollout"]["engine_kwargs"]["vllm"]["tool_call_parser"] = "llama3_json" - config["actor_rollout_ref"]["model"]["path"] = "meta-llama/Llama-3.2-3B-Instruct" + config["actor_rollout_ref"]["model"]["path"] = "~/Llama-3.2-3B" return config diff --git a/examples/search_r1/train_search_r1_agent_transition.py b/examples/search_r1/train_search_r1_agent_transition.py new file mode 100644 index 000000000..a98086355 --- /dev/null +++ b/examples/search_r1/train_search_r1_agent_transition.py @@ -0,0 +1,176 @@ +# Copyright (c) Microsoft. All rights reserved. + + +from __future__ import annotations + +import argparse +import os +from copy import deepcopy +from datetime import datetime +from typing import Any, Dict, Optional + +import pandas as pd +from search_r1_agent import SearchR1Agent + +import agentlightning as agl + +RL_TRAINING_CONFIG: Dict[str, Any] = { + "algorithm": { + "adv_estimator": "grpo", + "use_kl_in_reward": False, + }, + "data": { + "train_files": "data/train.parquet", + "val_files": "data/agent_test_50select.parquet", + "train_batch_size": 512, + "max_prompt_length": 6000, + "max_response_length": 4096, + "truncation": "error", + }, + "actor_rollout_ref": { + "rollout": { + "tensor_model_parallel_size": 1, + "n": 5, + "log_prob_micro_batch_size_per_gpu": 4, + "multi_turn": {"format": "hermes"}, + "name": "vllm", + "gpu_memory_utilization": 0.5, + "engine_kwargs": { + "vllm": { + "enable_auto_tool_choice": True, + "tool_call_parser": "hermes", + } + }, + "trace_aggregator": { + "mode": "transition", + } + }, + "actor": { + "ppo_mini_batch_size": 256, + "ppo_micro_batch_size_per_gpu": 4, + "optim": {"lr": 1e-6, "lr_warmup_steps_ratio": 0.95}, + "use_kl_loss": True, + "kl_loss_type": "low_var_kl", + "kl_loss_coef": 0.001, + "entropy_coeff": 0, + "clip_ratio_low": 0.2, + "clip_ratio_high": 0.3, + "fsdp_config": { + "param_offload": True, + "optimizer_offload": True, + }, + }, + "ref": { + "log_prob_micro_batch_size_per_gpu": 4, + "fsdp_config": {"param_offload": True}, + }, + "model": { + "path": "~/Llama-3.2-3B", + "use_remove_padding": True, + "enable_gradient_checkpointing": True, + }, + }, + "trainer": { + "n_gpus_per_node": 8, + "val_before_train": True, + "critic_warmup": 0, + "logger": ["console", "wandb"], + "project_name": "AgentLightning-SearchR1", + "experiment_name": "searchr1_minibatch256_runner32_transition_synced", + "nnodes": 1, + "test_freq": 10, + "save_freq":10, + "total_epochs": 15, + "total_training_steps": 300, + "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/searchr1_minibatch256_runner32_transition_synced/" + }, +} + + +def config_train_fast() -> Dict[str, Any]: + """A fast training run for CI testing purposes.""" + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + EXPERIMENT_NAME = 
f"searchr1_{timestamp}" + PROJECT_NAME = "AgentLightningCI" + + # Simulate writing to $GITHUB_OUTPUT if it’s set + github_output = os.getenv("GITHUB_OUTPUT") + if github_output: + with open(github_output, "a") as f: + f.write(f"project_name={PROJECT_NAME}\n") + f.write(f"run_name={EXPERIMENT_NAME}\n") + + print("Set environment variables:") + print(f"PROJECT_NAME={PROJECT_NAME}") + print(f"EXPERIMENT_NAME={EXPERIMENT_NAME}") + + config = deepcopy(RL_TRAINING_CONFIG) + config["actor_rollout_ref"]["rollout"]["gpu_memory_utilization"] = 0.6 + config["actor_rollout_ref"]["model"]["path"] = "Qwen/Qwen2.5-Coder-0.5B-Instruct" + config["data"]["val_files"] = "data/test_dev.parquet" + config["trainer"]["total_epochs"] = 1 + config["trainer"]["total_training_steps"] = 1 + config["trainer"]["experiment_name"] = EXPERIMENT_NAME + config["trainer"]["project_name"] = PROJECT_NAME + config["trainer"]["test_freq"] = 1 + return config + + +def config_train_qwen() -> Dict[str, Any]: + """A configuration for training with Qwen-2.5B.""" + + config = deepcopy(RL_TRAINING_CONFIG) + return config + + +def config_train_llama() -> Dict[str, Any]: + """A configuration for training with LLaMA-3.2-1B-Instruct. + + You will need a `HF_TOKEN` set to run with this config. + """ + + config = deepcopy(RL_TRAINING_CONFIG) + config["actor_rollout_ref"]["rollout"]["multi_turn"]["format"] = "llama3_json" + config["actor_rollout_ref"]["rollout"]["engine_kwargs"]["vllm"]["tool_call_parser"] = "llama3_json" + config["actor_rollout_ref"]["model"]["path"] = "~/Llama-3.2-3B" + return config + + +def train(config: Dict[str, Any]) -> None: + + agent = SearchR1Agent() + algorithm = agl.VERL(config) + trainer = agl.Trainer(n_runners=32, algorithm=algorithm) + + train_data = pd.read_parquet(config["data"]["train_files"]).to_dict(orient="records") # type: ignore + val_data = pd.read_parquet(config["data"]["val_files"]).to_dict(orient="records") # type: ignore + trainer.fit(agent, train_dataset=train_data, val_dataset=val_data) # type: ignore + + +def main() -> None: + """Main function to parse arguments and run training.""" + parser = argparse.ArgumentParser( + description="Train an Search-R1 agent using different model configurations" + ) + + parser.add_argument( + "config", + choices=["fast", "qwen", "llama"], + help="Training configuration: 'fast' (CI testing), 'qwen' (Qwen-2.5-Coder-1.5B), 'llama' (LLaMA-3.2-3B-Instruct)", + ) + + args = parser.parse_args() + + # Get the appropriate configuration + config_functions = {"fast": config_train_fast, "qwen": config_train_qwen, "llama": config_train_llama} + + config = config_functions[args.config]() + + print(f"Starting training with '{args.config}' configuration...") + + train(config) + + +if __name__ == "__main__": + main() From 67dad08a566a37f25a26ab1986074eae724a3dd6 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Wed, 3 Dec 2025 07:05:27 +0000 Subject: [PATCH 23/28] update logs code --- agentlightning/verl/daemon.py | 8 ++++++-- examples/search_r1/train_search_r1_agent.py | 4 ++-- examples/search_r1/train_search_r1_agent_transition.py | 4 ++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index 5c343bec9..5682c2869 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -65,8 +65,8 @@ def _none_special_token_sequence(ids): # case 1: fully match; case 2: special token mismatch only # case 1: template_mismatch == False, retoken_mismatch == False, others_mismatch == False, merge == 
True # case 2: template_mismatch == True, retoken_mismatch == False, others_mismatch == False, merge == False - if (not template_mismatch and not retoken_mismatch and not others_mismatch and merge) \ - or (template_mismatch and not retoken_mismatch and not others_mismatch and not merge): + if not ((not template_mismatch and not retoken_mismatch and not others_mismatch and merge) \ + or (template_mismatch and not retoken_mismatch and not others_mismatch and not merge)): with open("bad_case_jiahang.log", "a+") as f: print("-" * 20, file=f) print("full_ids:", file=f) @@ -919,6 +919,10 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, if current_merged_trace_idx not in merged_trace_idx: merged_trace_idx.append(current_merged_trace_idx) + # log data, only for debug testing + if len(merged_trace_idx) > 1: + unmerged_count += 1 + for current_merged_trace_idx in merged_trace_idx: prompt_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["prompt_ids"] accum_response_ids = sample_info["trace_list"][current_merged_trace_idx[0]]["response_ids"] diff --git a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent.py index c1d946d19..50f9cdf8c 100644 --- a/examples/search_r1/train_search_r1_agent.py +++ b/examples/search_r1/train_search_r1_agent.py @@ -68,7 +68,7 @@ "fsdp_config": {"param_offload": True}, }, "model": { - "path": "~/Llama-3.2-3B", + "path": "/home/aiscuser/Llama-3.2-3B", "use_remove_padding": True, "enable_gradient_checkpointing": True, }, @@ -136,7 +136,7 @@ def config_train_llama() -> Dict[str, Any]: config = deepcopy(RL_TRAINING_CONFIG) config["actor_rollout_ref"]["rollout"]["multi_turn"]["format"] = "llama3_json" config["actor_rollout_ref"]["rollout"]["engine_kwargs"]["vllm"]["tool_call_parser"] = "llama3_json" - config["actor_rollout_ref"]["model"]["path"] = "~/Llama-3.2-3B" + config["actor_rollout_ref"]["model"]["path"] = "/home/aiscuser/Llama-3.2-3B" return config diff --git a/examples/search_r1/train_search_r1_agent_transition.py b/examples/search_r1/train_search_r1_agent_transition.py index a98086355..52252286d 100644 --- a/examples/search_r1/train_search_r1_agent_transition.py +++ b/examples/search_r1/train_search_r1_agent_transition.py @@ -65,7 +65,7 @@ "fsdp_config": {"param_offload": True}, }, "model": { - "path": "~/Llama-3.2-3B", + "path": "/home/aiscuser/Llama-3.2-3B", "use_remove_padding": True, "enable_gradient_checkpointing": True, }, @@ -133,7 +133,7 @@ def config_train_llama() -> Dict[str, Any]: config = deepcopy(RL_TRAINING_CONFIG) config["actor_rollout_ref"]["rollout"]["multi_turn"]["format"] = "llama3_json" config["actor_rollout_ref"]["rollout"]["engine_kwargs"]["vllm"]["tool_call_parser"] = "llama3_json" - config["actor_rollout_ref"]["model"]["path"] = "~/Llama-3.2-3B" + config["actor_rollout_ref"]["model"]["path"] = "/home/aiscuser/Llama-3.2-3B" return config From e1dbc3cce6d4d22f90adeaf372521975244d4e7a Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Wed, 3 Dec 2025 09:35:35 +0000 Subject: [PATCH 24/28] update running scripts --- examples/search_r1/train_search_r1_agent.py | 6 +- .../search_r1/train_search_r1_agent_ins.py | 179 ++++++++++++++++++ .../train_search_r1_agent_transition.py | 4 +- .../train_search_r1_agent_transition_ins.py | 176 +++++++++++++++++ 4 files changed, 360 insertions(+), 5 deletions(-) create mode 100644 examples/search_r1/train_search_r1_agent_ins.py create mode 100644 examples/search_r1/train_search_r1_agent_transition_ins.py diff --git 
a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent.py index 50f9cdf8c..ef50b3249 100644 --- a/examples/search_r1/train_search_r1_agent.py +++ b/examples/search_r1/train_search_r1_agent.py @@ -78,14 +78,14 @@ "val_before_train": True, "critic_warmup": 0, "logger": ["console", "wandb"], - "project_name": "AgentLightning-SearchR1", - "experiment_name": "searchr1_minibatch256_runner32_trajectory", + "project_name": "AgentLightning-SearchR1-Base", + "experiment_name": "searchr1_minibatch256_runner32_trajectory_synced", "nnodes": 1, "test_freq": 10, "save_freq":10, "total_epochs": 15, "total_training_steps": 300, - "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/searchr1_minibatch256_runner32_trajectory/" + "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/Llama-3.2-3B/searchr1_minibatch256_runner32_trajectory_synced/" }, } diff --git a/examples/search_r1/train_search_r1_agent_ins.py b/examples/search_r1/train_search_r1_agent_ins.py new file mode 100644 index 000000000..e43eea0a1 --- /dev/null +++ b/examples/search_r1/train_search_r1_agent_ins.py @@ -0,0 +1,179 @@ +# Copyright (c) Microsoft. All rights reserved. + + +from __future__ import annotations + +import argparse +import os +from copy import deepcopy +from datetime import datetime +from typing import Any, Dict, Optional + +import pandas as pd +from search_r1_agent import SearchR1Agent + +import agentlightning as agl + +RL_TRAINING_CONFIG: Dict[str, Any] = { + "algorithm": { + "adv_estimator": "grpo", + "use_kl_in_reward": False, + }, + "data": { + "train_files": "data/train.parquet", + "val_files": "data/agent_test_50select.parquet", + "train_batch_size": 512, + "max_prompt_length": 6000, + "max_response_length": 4096, + "truncation": "error", + }, + "actor_rollout_ref": { + "rollout": { + "tensor_model_parallel_size": 1, + "n": 5, + "log_prob_micro_batch_size_per_gpu": 4, + "multi_turn": {"format": "hermes"}, + "name": "vllm", + "gpu_memory_utilization": 0.5, + "engine_kwargs": { + "vllm": { + "enable_auto_tool_choice": True, + "tool_call_parser": "hermes", + } + }, + "trace_aggregator": { + "mode": "trajectory", + "special_token_tolerance": 0, + "string_tolerance": 0, + "trajectory_max_length": 34384, + } + }, + "actor": { + "ppo_mini_batch_size": 256, + "ppo_micro_batch_size_per_gpu": 4, + "optim": {"lr": 1e-6, "lr_warmup_steps_ratio": 0.95}, + "use_kl_loss": True, + "kl_loss_type": "low_var_kl", + "kl_loss_coef": 0.001, + "entropy_coeff": 0, + "clip_ratio_low": 0.2, + "clip_ratio_high": 0.3, + "fsdp_config": { + "param_offload": True, + "optimizer_offload": True, + }, + }, + "ref": { + "log_prob_micro_batch_size_per_gpu": 4, + "fsdp_config": {"param_offload": True}, + }, + "model": { + "path": "meta-llama/Llama-3.2-3B-Instruct", + "use_remove_padding": True, + "enable_gradient_checkpointing": True, + }, + }, + "trainer": { + "n_gpus_per_node": 8, + "val_before_train": True, + "critic_warmup": 0, + "logger": ["console", "wandb"], + "project_name": "AgentLightning-SearchR1", + "experiment_name": "searchr1_minibatch256_runner32_trajectory_synced", + "nnodes": 1, + "test_freq": 10, + "save_freq":10, + "total_epochs": 15, + "total_training_steps": 300, + "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/Llama-3.2-3B-Instruct/searchr1_minibatch256_runner32_trajectory_synced/" + }, +} + + +def config_train_fast() -> Dict[str, Any]: + """A fast training run for CI testing purposes.""" + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + 
EXPERIMENT_NAME = f"searchr1_{timestamp}" + PROJECT_NAME = "AgentLightningCI" + + # Simulate writing to $GITHUB_OUTPUT if it’s set + github_output = os.getenv("GITHUB_OUTPUT") + if github_output: + with open(github_output, "a") as f: + f.write(f"project_name={PROJECT_NAME}\n") + f.write(f"run_name={EXPERIMENT_NAME}\n") + + print("Set environment variables:") + print(f"PROJECT_NAME={PROJECT_NAME}") + print(f"EXPERIMENT_NAME={EXPERIMENT_NAME}") + + config = deepcopy(RL_TRAINING_CONFIG) + config["actor_rollout_ref"]["rollout"]["gpu_memory_utilization"] = 0.6 + config["actor_rollout_ref"]["model"]["path"] = "Qwen/Qwen2.5-Coder-0.5B-Instruct" + config["data"]["val_files"] = "data/test_dev.parquet" + config["trainer"]["total_epochs"] = 1 + config["trainer"]["total_training_steps"] = 1 + config["trainer"]["experiment_name"] = EXPERIMENT_NAME + config["trainer"]["project_name"] = PROJECT_NAME + config["trainer"]["test_freq"] = 1 + return config + + +def config_train_qwen() -> Dict[str, Any]: + """A configuration for training with Qwen-2.5B.""" + + config = deepcopy(RL_TRAINING_CONFIG) + return config + + +def config_train_llama() -> Dict[str, Any]: + """A configuration for training with LLaMA-3.2-1B-Instruct. + + You will need a `HF_TOKEN` set to run with this config. + """ + + config = deepcopy(RL_TRAINING_CONFIG) + config["actor_rollout_ref"]["rollout"]["multi_turn"]["format"] = "llama3_json" + config["actor_rollout_ref"]["rollout"]["engine_kwargs"]["vllm"]["tool_call_parser"] = "llama3_json" + config["actor_rollout_ref"]["model"]["path"] = "meta-llama/Llama-3.2-3B-Instruct" + return config + + +def train(config: Dict[str, Any]) -> None: + + agent = SearchR1Agent() + algorithm = agl.VERL(config) + trainer = agl.Trainer(n_runners=32, algorithm=algorithm) + + train_data = pd.read_parquet(config["data"]["train_files"]).to_dict(orient="records") # type: ignore + val_data = pd.read_parquet(config["data"]["val_files"]).to_dict(orient="records") # type: ignore + trainer.fit(agent, train_dataset=train_data, val_dataset=val_data) # type: ignore + + +def main() -> None: + """Main function to parse arguments and run training.""" + parser = argparse.ArgumentParser( + description="Train an Search-R1 agent using different model configurations" + ) + + parser.add_argument( + "config", + choices=["fast", "qwen", "llama"], + help="Training configuration: 'fast' (CI testing), 'qwen' (Qwen-2.5-Coder-1.5B), 'llama' (LLaMA-3.2-3B-Instruct)", + ) + + args = parser.parse_args() + + # Get the appropriate configuration + config_functions = {"fast": config_train_fast, "qwen": config_train_qwen, "llama": config_train_llama} + + config = config_functions[args.config]() + + print(f"Starting training with '{args.config}' configuration...") + + train(config) + + +if __name__ == "__main__": + main() diff --git a/examples/search_r1/train_search_r1_agent_transition.py b/examples/search_r1/train_search_r1_agent_transition.py index 52252286d..07e13f752 100644 --- a/examples/search_r1/train_search_r1_agent_transition.py +++ b/examples/search_r1/train_search_r1_agent_transition.py @@ -75,14 +75,14 @@ "val_before_train": True, "critic_warmup": 0, "logger": ["console", "wandb"], - "project_name": "AgentLightning-SearchR1", + "project_name": "AgentLightning-SearchR1-Base", "experiment_name": "searchr1_minibatch256_runner32_transition_synced", "nnodes": 1, "test_freq": 10, "save_freq":10, "total_epochs": 15, "total_training_steps": 300, - "default_local_dir": 
"/mnt/teamdrive/search_r1/searchr1_checkpoints/searchr1_minibatch256_runner32_transition_synced/" + "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/Llama-3.2-3B/searchr1_minibatch256_runner32_transition_synced/" }, } diff --git a/examples/search_r1/train_search_r1_agent_transition_ins.py b/examples/search_r1/train_search_r1_agent_transition_ins.py new file mode 100644 index 000000000..7f74e5488 --- /dev/null +++ b/examples/search_r1/train_search_r1_agent_transition_ins.py @@ -0,0 +1,176 @@ +# Copyright (c) Microsoft. All rights reserved. + + +from __future__ import annotations + +import argparse +import os +from copy import deepcopy +from datetime import datetime +from typing import Any, Dict, Optional + +import pandas as pd +from search_r1_agent import SearchR1Agent + +import agentlightning as agl + +RL_TRAINING_CONFIG: Dict[str, Any] = { + "algorithm": { + "adv_estimator": "grpo", + "use_kl_in_reward": False, + }, + "data": { + "train_files": "data/train.parquet", + "val_files": "data/agent_test_50select.parquet", + "train_batch_size": 512, + "max_prompt_length": 6000, + "max_response_length": 4096, + "truncation": "error", + }, + "actor_rollout_ref": { + "rollout": { + "tensor_model_parallel_size": 1, + "n": 5, + "log_prob_micro_batch_size_per_gpu": 4, + "multi_turn": {"format": "hermes"}, + "name": "vllm", + "gpu_memory_utilization": 0.5, + "engine_kwargs": { + "vllm": { + "enable_auto_tool_choice": True, + "tool_call_parser": "hermes", + } + }, + "trace_aggregator": { + "mode": "transition", + } + }, + "actor": { + "ppo_mini_batch_size": 256, + "ppo_micro_batch_size_per_gpu": 4, + "optim": {"lr": 1e-6, "lr_warmup_steps_ratio": 0.95}, + "use_kl_loss": True, + "kl_loss_type": "low_var_kl", + "kl_loss_coef": 0.001, + "entropy_coeff": 0, + "clip_ratio_low": 0.2, + "clip_ratio_high": 0.3, + "fsdp_config": { + "param_offload": True, + "optimizer_offload": True, + }, + }, + "ref": { + "log_prob_micro_batch_size_per_gpu": 4, + "fsdp_config": {"param_offload": True}, + }, + "model": { + "path": "meta-llama/Llama-3.2-3B-Instruct", + "use_remove_padding": True, + "enable_gradient_checkpointing": True, + }, + }, + "trainer": { + "n_gpus_per_node": 8, + "val_before_train": True, + "critic_warmup": 0, + "logger": ["console", "wandb"], + "project_name": "AgentLightning-SearchR1", + "experiment_name": "searchr1_minibatch256_runner32_transition_synced", + "nnodes": 1, + "test_freq": 10, + "save_freq":10, + "total_epochs": 15, + "total_training_steps": 300, + "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/Llama-3.2-3B-Instruct/searchr1_minibatch256_runner32_transition_synced/" + }, +} + + +def config_train_fast() -> Dict[str, Any]: + """A fast training run for CI testing purposes.""" + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + EXPERIMENT_NAME = f"searchr1_{timestamp}" + PROJECT_NAME = "AgentLightningCI" + + # Simulate writing to $GITHUB_OUTPUT if it’s set + github_output = os.getenv("GITHUB_OUTPUT") + if github_output: + with open(github_output, "a") as f: + f.write(f"project_name={PROJECT_NAME}\n") + f.write(f"run_name={EXPERIMENT_NAME}\n") + + print("Set environment variables:") + print(f"PROJECT_NAME={PROJECT_NAME}") + print(f"EXPERIMENT_NAME={EXPERIMENT_NAME}") + + config = deepcopy(RL_TRAINING_CONFIG) + config["actor_rollout_ref"]["rollout"]["gpu_memory_utilization"] = 0.6 + config["actor_rollout_ref"]["model"]["path"] = "Qwen/Qwen2.5-Coder-0.5B-Instruct" + config["data"]["val_files"] = "data/test_dev.parquet" + 
config["trainer"]["total_epochs"] = 1 + config["trainer"]["total_training_steps"] = 1 + config["trainer"]["experiment_name"] = EXPERIMENT_NAME + config["trainer"]["project_name"] = PROJECT_NAME + config["trainer"]["test_freq"] = 1 + return config + + +def config_train_qwen() -> Dict[str, Any]: + """A configuration for training with Qwen-2.5B.""" + + config = deepcopy(RL_TRAINING_CONFIG) + return config + + +def config_train_llama() -> Dict[str, Any]: + """A configuration for training with LLaMA-3.2-1B-Instruct. + + You will need a `HF_TOKEN` set to run with this config. + """ + + config = deepcopy(RL_TRAINING_CONFIG) + config["actor_rollout_ref"]["rollout"]["multi_turn"]["format"] = "llama3_json" + config["actor_rollout_ref"]["rollout"]["engine_kwargs"]["vllm"]["tool_call_parser"] = "llama3_json" + config["actor_rollout_ref"]["model"]["path"] = "meta-llama/Llama-3.2-3B-Instruct" + return config + + +def train(config: Dict[str, Any]) -> None: + + agent = SearchR1Agent() + algorithm = agl.VERL(config) + trainer = agl.Trainer(n_runners=32, algorithm=algorithm) + + train_data = pd.read_parquet(config["data"]["train_files"]).to_dict(orient="records") # type: ignore + val_data = pd.read_parquet(config["data"]["val_files"]).to_dict(orient="records") # type: ignore + trainer.fit(agent, train_dataset=train_data, val_dataset=val_data) # type: ignore + + +def main() -> None: + """Main function to parse arguments and run training.""" + parser = argparse.ArgumentParser( + description="Train an Search-R1 agent using different model configurations" + ) + + parser.add_argument( + "config", + choices=["fast", "qwen", "llama"], + help="Training configuration: 'fast' (CI testing), 'qwen' (Qwen-2.5-Coder-1.5B), 'llama' (LLaMA-3.2-3B-Instruct)", + ) + + args = parser.parse_args() + + # Get the appropriate configuration + config_functions = {"fast": config_train_fast, "qwen": config_train_qwen, "llama": config_train_llama} + + config = config_functions[args.config]() + + print(f"Starting training with '{args.config}' configuration...") + + train(config) + + +if __name__ == "__main__": + main() From e5be217c4994e806722c99cd4c9da61254ddc68c Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Wed, 3 Dec 2025 10:46:56 +0000 Subject: [PATCH 25/28] update external store --- examples/search_r1/train_search_r1_agent_transition_ins.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/search_r1/train_search_r1_agent_transition_ins.py b/examples/search_r1/train_search_r1_agent_transition_ins.py index 7f74e5488..b9f65ffcc 100644 --- a/examples/search_r1/train_search_r1_agent_transition_ins.py +++ b/examples/search_r1/train_search_r1_agent_transition_ins.py @@ -159,6 +159,12 @@ def main() -> None: choices=["fast", "qwen", "llama"], help="Training configuration: 'fast' (CI testing), 'qwen' (Qwen-2.5-Coder-1.5B), 'llama' (LLaMA-3.2-3B-Instruct)", ) + parser.add_argument( + "--external-store-address", + type=str, + default="", + help="Connect to an external store instead of creating a new one in memory", + ) args = parser.parse_args() @@ -166,6 +172,7 @@ def main() -> None: config_functions = {"fast": config_train_fast, "qwen": config_train_qwen, "llama": config_train_llama} config = config_functions[args.config]() + config["external_store_address"]=args.external_store_address print(f"Starting training with '{args.config}' configuration...") From d458b26b3689d784535589f0e6f2f28e927bd0c1 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Thu, 4 Dec 2025 04:04:39 +0000 Subject: [PATCH 26/28] update external store (fix 
bug) --- .../search_r1/train_search_r1_agent_transition_ins.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/search_r1/train_search_r1_agent_transition_ins.py b/examples/search_r1/train_search_r1_agent_transition_ins.py index b9f65ffcc..a12029af9 100644 --- a/examples/search_r1/train_search_r1_agent_transition_ins.py +++ b/examples/search_r1/train_search_r1_agent_transition_ins.py @@ -137,11 +137,15 @@ def config_train_llama() -> Dict[str, Any]: return config -def train(config: Dict[str, Any]) -> None: +def train(config: Dict[str, Any], external_store_address: str = "") -> None: agent = SearchR1Agent() algorithm = agl.VERL(config) - trainer = agl.Trainer(n_runners=32, algorithm=algorithm) + if external_store_address: + store: Optional[agl.LightningStore] = agl.LightningStoreClient(external_store_address) + else: + store = None + trainer = agl.Trainer(n_runners=32, algorithm=algorithm, store=store) train_data = pd.read_parquet(config["data"]["train_files"]).to_dict(orient="records") # type: ignore val_data = pd.read_parquet(config["data"]["val_files"]).to_dict(orient="records") # type: ignore @@ -172,11 +176,10 @@ def main() -> None: config_functions = {"fast": config_train_fast, "qwen": config_train_qwen, "llama": config_train_llama} config = config_functions[args.config]() - config["external_store_address"]=args.external_store_address print(f"Starting training with '{args.config}' configuration...") - train(config) + train(config, external_store_address=args.external_store_address) if __name__ == "__main__": From 951b16832854b3d60e2c4144a3965eba0f2a2d02 Mon Sep 17 00:00:00 2001 From: jiahangxu Date: Thu, 4 Dec 2025 06:12:44 +0000 Subject: [PATCH 27/28] update trajectory-strict and trajectory-tolerant --- agentlightning/verl/config.yaml | 8 +- agentlightning/verl/daemon.py | 65 ++-- agentlightning/verl/trainer.py | 14 +- analyze_mismatch.py | 329 ++++++++++++++++++ .../search_r1/debug_train_search_r1_agent.py | 2 +- 5 files changed, 380 insertions(+), 38 deletions(-) create mode 100644 analyze_mismatch.py diff --git a/agentlightning/verl/config.yaml b/agentlightning/verl/config.yaml index 6c0163ee1..eaa35ce97 100644 --- a/agentlightning/verl/config.yaml +++ b/agentlightning/verl/config.yaml @@ -20,7 +20,7 @@ actor_rollout_ref: path: pkg://agentlightning.verl.async_server name: PatchedvLLMServer trace_aggregator: - mode: transition # transition or trajectory - special_token_tolerance: 10 # only supported in trajectory mode, suggest to set as n_turns - string_tolerance: 20 # only supported in trajectory mode, suggest to set as n_turns * 2 - trajectory_max_length: 8192 # only supported in trajectory mode, suggest to set as n_turns * (max_response_length + max_prompt_length) + mode: transition # transition, trajectory-strict, or trajectory-tolerant + special_token_tolerance: 10 # only supported in trajectory-tolerant mode, suggest to set as n_turns + string_tolerance: 20 # only supported in trajectory-tolerant mode, suggest to set as n_turns * 2 + trajectory_max_length: 8192 # supported in two trajectory modes, suggest to set as n_turns * (max_response_length + max_prompt_length) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index 5682c2869..8feeae407 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -31,7 +31,7 @@ ] -def logged_startswith(full_ids, prefix_ids, tokenizer): +def strict_startswith_with_log(full_ids, prefix_ids, tokenizer): template_mismatch, retoken_mismatch, others_mismatch = 
False, False, False if full_ids[:len(prefix_ids)] == prefix_ids: merge = True @@ -67,7 +67,7 @@ def _none_special_token_sequence(ids): # case 2: template_mismatch == True, retoken_mismatch == False, others_mismatch == False, merge == False if not ((not template_mismatch and not retoken_mismatch and not others_mismatch and merge) \ or (template_mismatch and not retoken_mismatch and not others_mismatch and not merge)): - with open("bad_case_jiahang.log", "a+") as f: + with open("mismatch_log/bad_case_unexpected.log", "a+") as f: print("-" * 20, file=f) print("full_ids:", file=f) print(full_ids, file=f) @@ -80,23 +80,23 @@ def _none_special_token_sequence(ids): # log data, only for debug testing def log_mismatch_detail(template_mismatch, retoken_mismatch, others_mismatch, full_ids, prefix_ids): if template_mismatch: - with open("template_mismatch.log", "a+") as f: + with open("mismatch_log/template_mismatch.log", "a+") as f: print("-" * 20, file=f) print(full_ids, file=f) print(prefix_ids, file=f) if retoken_mismatch: - with open("retoken_mismatch.log", "a+") as f: + with open("mismatch_log/retoken_mismatch.log", "a+") as f: print("-" * 20, file=f) print(full_ids, file=f) print(prefix_ids, file=f) if others_mismatch: - with open("others_mismatch.log", "a+") as f: + with open("mismatch_log/others_mismatch.log", "a+") as f: print("-" * 20, file=f) print(full_ids, file=f) print(prefix_ids, file=f) -def fuzzy_startswith(full_ids, prefix_ids, tokenizer, special_token_tolerance=0, string_tolerance=0): +def tolerant_startswith(full_ids, prefix_ids, tokenizer, special_token_tolerance=0, string_tolerance=0): def _special_token_sequence(ids): return [id for id in ids if id in tokenizer.all_special_ids] @@ -876,8 +876,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, data_id_list.append(sample_info["data_id"]) rollout_id_list.append(rollout_id) turn_index_list.append(turn_index) - - elif self.trace_aggregator.mode == "trajectory": + elif self.trace_aggregator.mode.startswith("trajectory"): response_mask_list: List[List[int]] = [] unmerged_count: int = 0 # only for debug template_mismatch_count, retoken_mismatch_count, others_mismatch_count = 0, 0, 0 @@ -894,25 +893,37 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, {"nxt_turn": trace["prompt_ids"][:] + trace["response_ids"][:], "cur": current_context[:]} ) response_per_turn_list.append(len(trace["response_ids"])) - template_mismatch, retoken_mismatch, others_mismatch, merged = logged_startswith( - trace["prompt_ids"] + trace["response_ids"], - current_context, - self.tokenizer, - ) - template_mismatch_count += int(template_mismatch) - retoken_mismatch_count += int(retoken_mismatch) - others_mismatch_count += int(others_mismatch) + if self.trace_aggregator.mode == "trajectory-strict": + template_mismatch, retoken_mismatch, others_mismatch, merged = strict_startswith_with_log( + trace["prompt_ids"] + trace["response_ids"], + current_context, + self.tokenizer, + ) + template_mismatch_count += int(template_mismatch) + retoken_mismatch_count += int(retoken_mismatch) + others_mismatch_count += int(others_mismatch) + if not merged: + log_mismatch_detail( # log data, only for debug testing + template_mismatch, + retoken_mismatch, + others_mismatch, + trace["prompt_ids"] + trace["response_ids"], + current_context, + ) + elif self.trace_aggregator.mode == "trajectory-tolerant": + merged = tolerant_startswith( + trace["prompt_ids"] + trace["response_ids"], + current_context, + self.tokenizer, 
+ special_token_tolerance=self.trace_aggregator.special_token_tolerance, + string_tolerance=self.trace_aggregator.string_tolerance, + ) + else: + raise ValueError(f"Unknown trace_aggregator mode: {self.trace_aggregator.mode}") if merged: current_context = trace["prompt_ids"] + trace["response_ids"] current_merged_trace_idx.append(turn_index) else: - log_mismatch_detail( # log data, only for debug testing - template_mismatch, - retoken_mismatch, - others_mismatch, - trace["prompt_ids"] + trace["response_ids"], - current_context, - ) merged_trace_idx.append(current_merged_trace_idx) current_merged_trace_idx = [turn_index] current_context = trace["prompt_ids"] + trace["response_ids"] @@ -938,7 +949,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, final_sample = sample_info["trace_list"][current_merged_trace_idx[-1]] response_ids = final_sample["prompt_ids"][prompt_length:] + final_sample["response_ids"] if len(response_ids) != len(accum_response_ids): # only for debug testing - with open("response_ids_num_mismatch.log", "a+") as f: + with open("mismatch_log/response_ids_num_mismatch.log", "a+") as f: print("-" * 20, file=f) print(response_ids, file=f) print(accum_response_ids, file=f) @@ -987,7 +998,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, batch_response_ids = torch.LongTensor(response_ids_list).to(device) response_attention_mask = torch.LongTensor(response_attention_mask_list).to(device) response_mask = ( - torch.LongTensor(response_mask_list).to(device) if self.trace_aggregator.mode == "trajectory" else None + torch.LongTensor(response_mask_list).to(device) if self.trace_aggregator.mode.startswith("trajectory") else None ) # Concatenate prompts and responses to form the full sequence @@ -1016,7 +1027,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, "position_ids": position_ids, "is_drop_mask": is_drop_mask, "token_level_scores": token_level_scores.contiguous(), - **({"response_mask": response_mask} if self.trace_aggregator.mode == "trajectory" else {}), + **({"response_mask": response_mask} if self.trace_aggregator.mode.startswith("trajectory") else {}), }, batch_size=n_transition, ) @@ -1038,7 +1049,7 @@ def get_train_data_batch(self, max_prompt_length: int, max_response_length: int, "training/template_mismatch_triplets": template_mismatch_count, "training/retoken_mismatch_triplets": retoken_mismatch_count, "training/others_mismatch_triplets": others_mismatch_count, - } if self.trace_aggregator.mode == "trajectory" else {}), + } if self.trace_aggregator.mode.startswith("trajectory") else {}), } # Add non-tensor data for advantage calculation and logging diff --git a/agentlightning/verl/trainer.py b/agentlightning/verl/trainer.py index c8fe75c66..9d044cae7 100644 --- a/agentlightning/verl/trainer.py +++ b/agentlightning/verl/trainer.py @@ -157,15 +157,17 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = True, suffix: str def log_step_for_mismatch_detail(step: int) -> None: - with open("template_mismatch.log", "a+") as f: + import os + os.makedirs("mismatch_log", exist_ok=True) + with open("mismatch_log/template_mismatch.log", "a+") as f: print("-" * 10 + f" Step {step}" + "-" * 10, file=f) - with open("retoken_mismatch.log", "a+") as f: + with open("mismatch_log/retoken_mismatch.log", "a+") as f: print("-" * 10 + f" Step {step}" + "-" * 10, file=f) - with open("others_mismatch.log", "a+") as f: + with open("mismatch_log/others_mismatch.log", "a+") as f: 
         print("-" * 10 + f" Step {step}" + "-" * 10, file=f)
-    with open("response_ids_num_mismatch.log", "a+") as f:
+    with open("mismatch_log/response_ids_num_mismatch.log", "a+") as f:
         print("-" * 10 + f" Step {step}" + "-" * 10, file=f)
-    with open("bad_case_jiahang.log", "a+") as f:
+    with open("mismatch_log/bad_case_unexpected.log", "a+") as f:
         print("-" * 10 + f" Step {step}" + "-" * 10, file=f)
@@ -234,7 +236,7 @@ def _train_step(self, batch_dict: dict) -> dict:
         batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch(
             max_prompt_length=self.config.data.max_prompt_length,
             max_response_length=self.config.actor_rollout_ref.rollout.trace_aggregator.trajectory_max_length \
-                if self.config.actor_rollout_ref.rollout.trace_aggregator.mode == "trajectory" else \
+                if self.config.actor_rollout_ref.rollout.trace_aggregator.mode.startswith("trajectory") else \
                 self.config.data.max_response_length,
             device=gen_batch.batch["fake_ids"].device,
         )
diff --git a/analyze_mismatch.py b/analyze_mismatch.py
new file mode 100644
index 000000000..2b9d0ba24
--- /dev/null
+++ b/analyze_mismatch.py
@@ -0,0 +1,329 @@
+import tokenizers
+from transformers import AutoTokenizer
+model_dir = 'meta-llama/Llama-3.2-3B-Instruct'
+tok = AutoTokenizer.from_pretrained(str(model_dir), local_files_only=True, use_fast=True)
+
+count = 0
+
+def find_min_diff_lengths(list_a, list_b):
+    """
+    Compute the minimum total length of the non-matching parts of two lists.
+    This is done via the longest common subsequence (LCS): the function returns
+    the lengths of the portions of each list that fall outside the LCS.
+
+    Args:
+        list_a (list): The first list.
+        list_b (list): The second list.
+
+    Returns:
+        tuple: A tuple (diff_a_len, diff_b_len) giving the lengths of the parts
+               of list_a and list_b that do not belong to the LCS.
+    """
+    len_a = len(list_a)
+    len_b = len(list_b)
+
+    # 1. Initialize the dynamic-programming (DP) table.
+    # dp[i][j] stores the LCS length of list_a[:i] and list_b[:j];
+    # the table has shape (len_a + 1) x (len_b + 1).
+    dp = [[0] * (len_b + 1) for _ in range(len_a + 1)]
+
+    # 2. Fill the DP table.
+    for i in range(1, len_a + 1):
+        for j in range(1, len_b + 1):
+            if list_a[i - 1] == list_b[j - 1]:
+                # Current elements match: extend the LCS by one.
+                dp[i][j] = dp[i - 1][j - 1] + 1
+            else:
+                # No match: keep the larger LCS length from dropping list_a[i-1] or list_b[j-1].
+                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
+
+    # 3. Read off the length of the longest common subsequence (LCS).
+    lcs_length = dp[len_a][len_b]
+
+    # 4. Compute the lengths of the differing parts.
+    # Diff length = total length - LCS length.
+    diff_a_len = len_a - lcs_length
+    diff_b_len = len_b - lcs_length
+
+    return (diff_a_len, diff_b_len)
+
+# # --- Example calls ---
+# list_1 = [1, 2, 3, 5, 6, 9]
+# list_2 = [1, 2, 4, 3, 9]
+
+# result = find_min_diff_lengths(list_1, list_2)
+# print(f"List A: {list_1}")
+# print(f"List B: {list_2}")
+# print(f"Minimum diff lengths (A, B): {result}")
+
+# # Example 2
+# list_a = ["A", "B", "C", "D", "E"]
+# list_b = ["A", "F", "D", "E"]
+# # LCS: ["A", "D", "E"] -> length 3
+# # non-LCS part of A: ["B", "C"] -> length 2
+# # non-LCS part of B: ["F"] -> length 1
+# result_2 = find_min_diff_lengths(list_a, list_b)
+# print("\n--- Example 2 ---")
+# print(f"List A: {list_a}")
+# print(f"List B: {list_b}")
+# print(f"Minimum diff lengths (A, B): {result_2}")
+
+def logged_startswith(full_ids, prefix_ids, tokenizer):
+    template_mismatch, retoken_mismatch, others_mismatch = False, False, False
+    if full_ids[:len(prefix_ids)] == prefix_ids:
+        merge = True
+        return template_mismatch, retoken_mismatch, others_mismatch, merge
+    else:
+        merge = False
+
+    def _special_token_sequence(ids):
+        return [id for id in ids if id in tokenizer.all_special_ids]
+
+    def _none_special_token_sequence(ids):
+        return [id for id in ids if id not in tokenizer.all_special_ids]
+
+    # First, handle special tokens
+    full_special_ids = _special_token_sequence(full_ids)
+    prefix_special_ids = _special_token_sequence(prefix_ids)
+    if len(full_special_ids) != len(prefix_special_ids) or sum(1 for a, b in zip(full_special_ids, prefix_special_ids) if a != b) > 0:
+        template_mismatch = True
+
+    # Next, handle string content
+    full_content_ids = _none_special_token_sequence(full_ids)
+    prefix_content_ids = _none_special_token_sequence(prefix_ids)
+    full_string = tokenizer.decode(full_ids, skip_special_tokens=True)
+    prefix_string = tokenizer.decode(prefix_ids, skip_special_tokens=True)
+    if full_content_ids[:len(prefix_content_ids)] != prefix_content_ids and full_string.startswith(prefix_string):
+        retoken_mismatch = True
+        # diff_segments = find_all_diff_segments(full_content_ids[:len(prefix_content_ids)], prefix_content_ids)
+        # if len(diff_segments) == 1:
+        #     diff_a, diff_b = diff_segments[0]
+        #     diff_string_a = tokenizer.decode(diff_a, skip_special_tokens=True)
+        #     diff_string_b = tokenizer.decode(diff_b, skip_special_tokens=True)
+        #     if diff_string_a == " max_j:
+        return False  # no possible prefix length
+
+    prev_start = max(0, 0 - string_tolerance)
+    prev_end = min(n, 0 + string_tolerance)
+    prev = [j for j in range(prev_start, prev_end + 1)]
+
+    for j_idx, j in enumerate(range(prev_start, prev_end + 1)):
+        if min_j <= j <= max_j and prev[j_idx] <= string_tolerance:
+            return True
+
+    for i in range(1, m + 1):
+        # valid j range for this row
+        start_j = max(0, i - string_tolerance)
+        end_j = min(n, i + string_tolerance)
+        cur_len = end_j - start_j + 1
+        cur = [0] * cur_len
+
+        for idx, j in enumerate(range(start_j, end_j + 1)):
+            del_cost = None
+            prev_start = max(0, (i - 1) - string_tolerance)
+            prev_end = min(n, (i - 1) + string_tolerance)
+            if prev_start <= j <= prev_end:
+                del_cost = prev[j - prev_start] + 1
+            else:
+                del_cost = abs((i - 1) - j) + 1  # safe over-approximation
+
+            ins_cost = None
+            if j - 1 >= start_j:
+                ins_cost = cur[idx - 1] + 1
+            else:
+                ins_cost = abs(i - (j - 1)) + 1
+
+            sub_cost = None
+            if prev_start <= (j - 1) <= prev_end:
+                sub_cost = prev[(j - 1) - prev_start] + (0 if prefix_string[i - 1] == full_string[j - 1] else 1)
+            else:
+                sub_cost = abs((i - 1) - (j - 1)) + (0 if prefix_string[i - 1] == full_string[j - 1] else 1)
+
+            cur[idx] = min(del_cost, ins_cost, sub_cost)
+
+        for idx, j in enumerate(range(start_j, end_j + 1)):
+            if min_j <= j <= max_j and cur[idx] <= string_tolerance:
+                return True
+        prev = cur
+    return False
+
+
+def find_all_diff_segments(a, b):
+    i = j = 0
+    n1, n2 = len(a), len(b)
+    diffs = []
+    curr_a, curr_b = [], []
+
+    while i < n1 or j < n2:
+        # If neither list is exhausted and the current elements match, close the diff segment (if one is open)
+        if i < n1 and j < n2 and a[i] == b[j]:
+            if curr_a or curr_b:
+                diffs.append((curr_a, curr_b))
+                curr_a, curr_b = [], []
+            i += 1
+            j += 1
+            continue
+
+        # Otherwise the elements differ and belong to the current diff segment
+        if i < n1:
+            curr_a.append(a[i])
+        if j < n2:
+            curr_b.append(b[j])
+        i += 1 if i < n1 else 0
+        j += 1 if j < n2 else 0
+
+    # Flush any diff segment still open at the end
+    if curr_a or curr_b:
+        diffs.append((curr_a, curr_b))
+
+    return diffs
+
+
+import re
+from typing import List, Dict, Any, Tuple
+
+def parse_step_data(lines: List[str]) -> List[Tuple[List[float], List[float]]]:
+    res = []
+    def str_to_float_list(s: str) -> List[float]:
+        numbers = re.findall(r'(\d+\.?\d*)', s)
+        return [int(n) for n in numbers]
+    data_lines = [line.strip() for line in lines if line.strip() and not line.strip().startswith('---')]
+    for i in range(0, len(data_lines), 2):
+        res.append((str_to_float_list(data_lines[i]), str_to_float_list(data_lines[i+1])))
+    return res
+
+def process_txt_file(file_path: str) -> Dict[str, Any]:
+    data: Dict[str, Any] = {}
+    with open(file_path, 'r', encoding='utf-8') as f:
+        content = f.readlines()
+
+    current_step = -1
+    current_step_lines: List[str] = []
+
+    for line in content:
+        line = line.strip()
+        step_match = re.match(r'---------- Step (\d+)----------', line)
+        if step_match:
+            if current_step != -1:
+                step_key = f"step_{current_step}"
+                data[step_key] = parse_step_data(current_step_lines)
+
+            current_step = int(step_match.group(1))
+            current_step_lines = []
+            continue
+
+        if current_step == -1 and line:
+            if line.startswith('[') or line.startswith('---'):
+                current_step = 0
+
+        if current_step != -1:
+            current_step_lines.append(line)
+
+    if current_step != -1:
+        step_key = f"step_{current_step}"
+        data[step_key] = parse_step_data(current_step_lines)
+
+    return data
+
+result_data = process_txt_file("/home/jiahangxu/teamdrive/search_r1/mismatch_logs/Llama-3.2-3B-Instruct/searchr1_minibatch256_runner32_trajectory_synced/retoken_mismatch.log")
+
+for step, values in result_data.items():
+    print("----------", step, "--------", len(values))
+    with open("retoken_mismatch.log", "a+") as f:
+        print("----------", step, "--------", len(values), file=f)
+    for item in values:
+        record = logged_startswith(item[0], item[1], tok)
+        # print(record[1], record[3])
+        # assert record[1] == True and record[3] == False
+        if not (record[1] == True and record[3] == False):
+            import pdb; pdb.set_trace()
+            print("Mismatch found:")
+            print("Full IDs: ", item[0])
+            print("Prefix IDs: ", item[1])
+            print(f"template_mismatch: {record[0]}, retoken_mismatch: {record[1]}, others_mismatch: {record[2]}, merge: {record[3]}")
+    print("Finished step:", step)
+    import pdb; pdb.set_trace()
+
+    # TODO: for retoken mismatch, compute what fraction of the total token count the differing diff tokens account for
+
\ No newline at end of file
diff --git a/examples/search_r1/debug_train_search_r1_agent.py b/examples/search_r1/debug_train_search_r1_agent.py
index a986b1039..8c9908209 100644
--- a/examples/search_r1/debug_train_search_r1_agent.py
+++ b/examples/search_r1/debug_train_search_r1_agent.py
@@ -42,7 +42,7 @@
             }
         },
         "trace_aggregator": {
-            "mode": "trajectory",
+            "mode": "trajectory-strict",
             "trajectory_max_length": 34384,
         }
     },

From 5a5d54fbbf417f0f5cd63eee61c6366100ba0fa0 Mon Sep 17 00:00:00 2001
From: jiahangxu
Date: Thu, 4 Dec 2025 06:19:43 +0000
Subject: [PATCH 28/28] update training scripts with new args

---
 .../search_r1/train_search_r1_agent_ins.py    |   4 +-
 ...ent.py => train_search_r1_agent_strict.py} |   4 +-
 .../train_search_r1_agent_tolerant.py         | 179 ++++++++++++++++++
 3 files changed, 181 insertions(+), 6 deletions(-)
 rename examples/search_r1/{train_search_r1_agent.py => train_search_r1_agent_strict.py} (97%)
 create mode 100644 examples/search_r1/train_search_r1_agent_tolerant.py

diff --git a/examples/search_r1/train_search_r1_agent_ins.py b/examples/search_r1/train_search_r1_agent_ins.py
index e43eea0a1..09d2c34d2 100644
--- a/examples/search_r1/train_search_r1_agent_ins.py
+++ b/examples/search_r1/train_search_r1_agent_ins.py
@@ -42,9 +42,7 @@
             }
         },
         "trace_aggregator": {
-            "mode": "trajectory",
-            "special_token_tolerance": 0,
-            "string_tolerance": 0,
+            "mode": "trajectory-strict",
             "trajectory_max_length": 34384,
         }
     },
diff --git a/examples/search_r1/train_search_r1_agent.py b/examples/search_r1/train_search_r1_agent_strict.py
similarity index 97%
rename from examples/search_r1/train_search_r1_agent.py
rename to examples/search_r1/train_search_r1_agent_strict.py
index ef50b3249..7d56bd2c5 100644
--- a/examples/search_r1/train_search_r1_agent.py
+++ b/examples/search_r1/train_search_r1_agent_strict.py
@@ -42,9 +42,7 @@
             }
         },
         "trace_aggregator": {
-            "mode": "trajectory",
-            "special_token_tolerance": 0,
-            "string_tolerance": 0,
+            "mode": "trajectory-strict",  # only allow an exact token-id match
             "trajectory_max_length": 34384,
         }
     },
diff --git a/examples/search_r1/train_search_r1_agent_tolerant.py b/examples/search_r1/train_search_r1_agent_tolerant.py
new file mode 100644
index 000000000..56eadea4d
--- /dev/null
+++ b/examples/search_r1/train_search_r1_agent_tolerant.py
@@ -0,0 +1,179 @@
+# Copyright (c) Microsoft. All rights reserved.
+ + +from __future__ import annotations + +import argparse +import os +from copy import deepcopy +from datetime import datetime +from typing import Any, Dict, Optional + +import pandas as pd +from search_r1_agent import SearchR1Agent + +import agentlightning as agl + +RL_TRAINING_CONFIG: Dict[str, Any] = { + "algorithm": { + "adv_estimator": "grpo", + "use_kl_in_reward": False, + }, + "data": { + "train_files": "data/train.parquet", + "val_files": "data/agent_test_50select.parquet", + "train_batch_size": 512, + "max_prompt_length": 6000, + "max_response_length": 4096, + "truncation": "error", + }, + "actor_rollout_ref": { + "rollout": { + "tensor_model_parallel_size": 1, + "n": 5, + "log_prob_micro_batch_size_per_gpu": 4, + "multi_turn": {"format": "hermes"}, + "name": "vllm", + "gpu_memory_utilization": 0.5, + "engine_kwargs": { + "vllm": { + "enable_auto_tool_choice": True, + "tool_call_parser": "hermes", + } + }, + "trace_aggregator": { + "mode": "trajectory-tolerant", + "special_token_tolerance": 0, # only allow re-token mismatches + "string_tolerance": 0, # only allow re-token mismatches + "trajectory_max_length": 34384, + } + }, + "actor": { + "ppo_mini_batch_size": 256, + "ppo_micro_batch_size_per_gpu": 4, + "optim": {"lr": 1e-6, "lr_warmup_steps_ratio": 0.95}, + "use_kl_loss": True, + "kl_loss_type": "low_var_kl", + "kl_loss_coef": 0.001, + "entropy_coeff": 0, + "clip_ratio_low": 0.2, + "clip_ratio_high": 0.3, + "fsdp_config": { + "param_offload": True, + "optimizer_offload": True, + }, + }, + "ref": { + "log_prob_micro_batch_size_per_gpu": 4, + "fsdp_config": {"param_offload": True}, + }, + "model": { + "path": "/home/aiscuser/Llama-3.2-3B", + "use_remove_padding": True, + "enable_gradient_checkpointing": True, + }, + }, + "trainer": { + "n_gpus_per_node": 8, + "val_before_train": True, + "critic_warmup": 0, + "logger": ["console", "wandb"], + "project_name": "AgentLightning-SearchR1-Base", + "experiment_name": "searchr1_minibatch256_runner32_trajectory_tolerant_synced", + "nnodes": 1, + "test_freq": 10, + "save_freq":10, + "total_epochs": 15, + "total_training_steps": 300, + "default_local_dir": "/mnt/teamdrive/search_r1/searchr1_checkpoints/Llama-3.2-3B/searchr1_minibatch256_runner32_trajectory_tolerant_synced/" + }, +} + + +def config_train_fast() -> Dict[str, Any]: + """A fast training run for CI testing purposes.""" + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + EXPERIMENT_NAME = f"searchr1_{timestamp}" + PROJECT_NAME = "AgentLightningCI" + + # Simulate writing to $GITHUB_OUTPUT if it’s set + github_output = os.getenv("GITHUB_OUTPUT") + if github_output: + with open(github_output, "a") as f: + f.write(f"project_name={PROJECT_NAME}\n") + f.write(f"run_name={EXPERIMENT_NAME}\n") + + print("Set environment variables:") + print(f"PROJECT_NAME={PROJECT_NAME}") + print(f"EXPERIMENT_NAME={EXPERIMENT_NAME}") + + config = deepcopy(RL_TRAINING_CONFIG) + config["actor_rollout_ref"]["rollout"]["gpu_memory_utilization"] = 0.6 + config["actor_rollout_ref"]["model"]["path"] = "Qwen/Qwen2.5-Coder-0.5B-Instruct" + config["data"]["val_files"] = "data/test_dev.parquet" + config["trainer"]["total_epochs"] = 1 + config["trainer"]["total_training_steps"] = 1 + config["trainer"]["experiment_name"] = EXPERIMENT_NAME + config["trainer"]["project_name"] = PROJECT_NAME + config["trainer"]["test_freq"] = 1 + return config + + +def config_train_qwen() -> Dict[str, Any]: + """A configuration for training with Qwen-2.5B.""" + + config = deepcopy(RL_TRAINING_CONFIG) + return config + + +def 
config_train_llama() -> Dict[str, Any]:
+    """A configuration for training with LLaMA-3.2-3B.
+
+    You will need a `HF_TOKEN` set to run with this config.
+    """
+
+    config = deepcopy(RL_TRAINING_CONFIG)
+    config["actor_rollout_ref"]["rollout"]["multi_turn"]["format"] = "llama3_json"
+    config["actor_rollout_ref"]["rollout"]["engine_kwargs"]["vllm"]["tool_call_parser"] = "llama3_json"
+    config["actor_rollout_ref"]["model"]["path"] = "/home/aiscuser/Llama-3.2-3B"
+    return config
+
+
+def train(config: Dict[str, Any]) -> None:
+
+    agent = SearchR1Agent()
+    algorithm = agl.VERL(config)
+    trainer = agl.Trainer(n_runners=32, algorithm=algorithm)
+
+    train_data = pd.read_parquet(config["data"]["train_files"]).to_dict(orient="records")  # type: ignore
+    val_data = pd.read_parquet(config["data"]["val_files"]).to_dict(orient="records")  # type: ignore
+    trainer.fit(agent, train_dataset=train_data, val_dataset=val_data)  # type: ignore
+
+
+def main() -> None:
+    """Main function to parse arguments and run training."""
+    parser = argparse.ArgumentParser(
+        description="Train a Search-R1 agent using different model configurations"
+    )
+
+    parser.add_argument(
+        "config",
+        choices=["fast", "qwen", "llama"],
+        help="Training configuration: 'fast' (CI testing), 'qwen' (Qwen-2.5-Coder-1.5B), 'llama' (LLaMA-3.2-3B-Instruct)",
+    )
+
+    args = parser.parse_args()
+
+    # Get the appropriate configuration
+    config_functions = {"fast": config_train_fast, "qwen": config_train_qwen, "llama": config_train_llama}
+
+    config = config_functions[args.config]()
+
+    print(f"Starting training with '{args.config}' configuration...")
+
+    train(config)
+
+
+if __name__ == "__main__":
+    main()
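
For context, a minimal standalone sketch of the behavioural difference between the two aggregator modes configured in these scripts follows. The toy vocabulary, the toy_decode helper, and the token ids are invented purely for illustration; the real code path goes through strict_startswith_with_log / tolerant_startswith with the actual HuggingFace tokenizer.

from typing import List

# Toy vocabulary; ids and strings are made up for illustration only.
VOCAB = {1: "Hel", 2: "lo", 3: " wor", 4: "ld", 5: "Hell", 6: "o"}

def toy_decode(ids: List[int]) -> str:
    return "".join(VOCAB[i] for i in ids)

def strict_merge(full_ids: List[int], prefix_ids: List[int]) -> bool:
    # "trajectory-strict": the previous context must be an exact token-id prefix of the new turn.
    return full_ids[:len(prefix_ids)] == prefix_ids

def tolerant_merge(full_ids: List[int], prefix_ids: List[int]) -> bool:
    # Rough analogue of "trajectory-tolerant" with both tolerances at 0:
    # additionally accept a retokenization that decodes to the same prefix string.
    return strict_merge(full_ids, prefix_ids) or toy_decode(full_ids).startswith(toy_decode(prefix_ids))

prev_context = [1, 2]       # decodes to "Hello"
next_turn = [5, 6, 3, 4]    # decodes to "Hello world", tokenized differently

print(strict_merge(next_turn, prev_context))    # False: the token ids differ
print(tolerant_merge(next_turn, prev_context))  # True: the decoded prefix is identical

With special_token_tolerance and string_tolerance left at 0, the tolerant mode is meant to accept exactly this kind of retokenization mismatch, while the strict mode only merges turns whose token ids extend the accumulated context verbatim.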