Description
I am doing the Tunix Kaggle competition but ran into a bug when setting `TRAIN_MICRO_BATCH_SIZE` to 1 in this example notebook.
tunix/tunix/generate/sampler.py, lines 328 to 330 in 1fe360d:

```python
input_mask = jnp.ones_like(token_buffer, dtype=jnp.bool_)
token_buffer = token_buffer.at[:, :num_input_tokens].set(all_input_ids)
input_mask = input_mask.at[:, :num_input_tokens].set(
```
When invoking `tunix.generate.sampler.Sampler` with `batch_size=1`, the generated output is always truncated after just a few tokens; the call never reaches `<end_of_turn>` even though the same prompt works fine with `batch_size=2`. Digging into the sampler code shows that:
- `token_buffer` is initialized in `init_sample_state()` as `(batch_size, total_sampling_steps)` (e.g. `(1, 896)`).
- `_sample()` is supposed to write each new token via `token_buffer = token_buffer.at[:, decoding_step + 1].set(next_token_candidate)` (see the sketch after this list).
- During the JAX `while_loop` (`_decode_fn`), `_sample_step` repeatedly calls `_sample()`, but when `batch_size=1` the buffer never grows beyond the first 2–4 entries.
- By the time token extraction runs (`__call__`, lines 735-782), `token_buffer` has shape `(1, 2)` (or `(1, 4)`), so `token_buffer[max_prompt_length:]` only contains a few tokens and `find_first_eos_idx` returns immediately, resulting in truncated outputs. Running the same logic with `batch_size=2` produces a normal-length buffer and completion.
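For reference, this is roughly the buffer-filling pattern I understand the sampler to use. This is a simplified, self-contained sketch, not the actual Tunix code; `run_decode` and the dummy next-token choice are made up for illustration:

```python
import jax
import jax.numpy as jnp


def run_decode(batch_size: int, total_sampling_steps: int = 16):
  # Pre-allocated (batch_size, total_sampling_steps) buffer, filled one
  # column per decoding step inside jax.lax.while_loop.
  token_buffer = jnp.zeros((batch_size, total_sampling_steps), dtype=jnp.int32)

  def cond_fn(state):
    step, _ = state
    return step < total_sampling_steps - 1

  def body_fn(state):
    step, buf = state
    # Stand-in for the model's greedy next-token choice.
    next_token = jnp.full((batch_size,), step + 1, dtype=jnp.int32)
    buf = buf.at[:, step + 1].set(next_token)
    return step + 1, buf

  _, token_buffer = jax.lax.while_loop(cond_fn, body_fn, (0, token_buffer))
  return token_buffer


# The buffer shape should not depend on batch_size:
print(run_decode(1).shape)  # expected (1, 16)
print(run_decode(2).shape)  # expected (2, 16)
```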
I added debug host-callbacks around `_sample_step` to log `token_buffer.shape` before/after each iteration, and the logs confirm that the `.at[:, decoding_step + 1].set(...)` update silently stops writing when there is a single example. This makes the sampler unusable when `TRAIN_MICRO_BATCH_SIZE=1`, yet it works fine for larger batches.
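For anyone reproducing that logging, one approach that works from inside a jitted loop body is `jax.debug.print` (my placement inside `_sample_step` is approximate); since shapes are static under `jit`, printing the buffer values directly also shows whether the write actually lands:

```python
import jax


def log_token_buffer(decoding_step, token_buffer):
  # Runtime print that is safe inside jit / lax.while_loop; shows whether
  # the .at[:, decoding_step + 1].set(...) update is actually writing.
  jax.debug.print(
      "decoding_step={s}, buffer head={t}",
      s=decoding_step,
      t=token_buffer[:, :8],
  )
```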
Steps to reproduce
1. Instantiate the sampler used in `gemma2-grpo-original-demo.ipynb`, e.g. the Tunix transformer + tokenizer + cache config.
2. Call `sampler(input_strings=[prompt], ...)` with `max_generation_steps` large enough to reach `<end_of_turn>`, using greedy sampling (`temperature=0.0001`, `top_k=1`).
3. Inspect `result.tokens[0]` / `result.text[0]` – the output is only a few tokens long and often doesn't end with `<end_of_turn>`.
4. Repeat step 2 with `input_strings=[prompt, prompt]` (`batch_size=2`) and observe a normal-length completion. (A rough repro sketch follows below.)
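A rough repro sketch of the steps above. The constructor and call arguments are my approximation of what the demo notebook builds; exact keyword names may differ:

```python
from tunix.generate import sampler as sampler_lib

# transformer, tokenizer and cache_config are whatever the
# gemma2-grpo-original-demo.ipynb notebook constructs.
sampler = sampler_lib.Sampler(
    transformer=transformer,
    tokenizer=tokenizer,
    cache_config=cache_config,
)

prompt = (
    "<start_of_turn>user\nWhat is 2 + 2?<end_of_turn>\n"
    "<start_of_turn>model\n"
)

# Steps 2/3: batch_size=1, output comes back truncated after a few tokens.
single = sampler(
    input_strings=[prompt],
    max_generation_steps=768,
    temperature=0.0001,
    top_k=1,
)
print(single.text[0])

# Step 4: batch_size=2, the same prompt completes normally.
pair = sampler(
    input_strings=[prompt, prompt],
    max_generation_steps=768,
    temperature=0.0001,
    top_k=1,
)
print(pair.text[0])
```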
Expected behavior
`token_buffer` should stay `(1, total_sampling_steps)` and `_sample()` should keep writing tokens until EOS, just as it does when `batch_size=2`.
Actual behavior
`token_buffer` stops at ~2–4 slots, so the extractor only returns the first few tokens and the sample is truncated. This happens inside `_sample()` / `_sample_step`, meaning the loop never expands the buffer for single-item batches.
Proposed fix
Investigate how `_sample()` / `_sample_step` behaves under JAX when `batch_size=1`; the `.at` update may be optimized out or shape-inferred incorrectly. Ensuring the `token_buffer` update runs (and that `token_buffer` stays `(batch_size, total_sampling_steps)`) should fix the bug. If necessary, add a regression test for `batch_size=1`.
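A possible shape for such a regression test (hypothetical fixtures; a real test would build the sampler from the existing Tunix test utilities):

```python
def test_batch_size_one_matches_batch_size_two(sampler, prompt):
  kwargs = dict(
      max_generation_steps=128,
      temperature=0.0001,
      top_k=1,
  )
  single = sampler(input_strings=[prompt], **kwargs)
  pair = sampler(input_strings=[prompt, prompt], **kwargs)

  # Greedy decoding of the same prompt should not depend on batch size:
  # the single-item batch must produce the same, non-truncated token stream.
  assert list(single.tokens[0]) == list(pair.tokens[0])
```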