
Single-item batches truncate sampler token_buffer instead of populating it #809

@Cascoopman

I am working on the Tunix Kaggle competition but ran into a bug when setting TRAIN_MICRO_BATCH_SIZE to 1 in this example notebook. The buffer setup in question, from init_sample_state() in tunix.generate.sampler:

input_mask = jnp.ones_like(token_buffer, dtype=jnp.bool_)
token_buffer = token_buffer.at[:, :num_input_tokens].set(all_input_ids)
input_mask = input_mask.at[:, :num_input_tokens].set(

Description

When invoking tunix.generate.sampler.Sampler with batch_size=1, the generated output is always truncated after just a few tokens; the call never reaches <end_of_turn> even though the same prompt works fine with batch_size=2. Digging into the sampler code shows that:

  • token_buffer is initialized in init_sample_state() as (batch_size, total_sampling_steps) (e.g., (1, 896)).
  • _sample() is supposed to write each new token via token_buffer = token_buffer.at[:, decoding_step + 1].set(next_token_candidate) (see the sketch after this list).
  • During the JAX while_loop (_decode_fn), _sample_step repeatedly calls _sample(), but when batch_size=1 the buffer never grows beyond the first 2–4 entries.
  • By the time token extraction runs (__call__, lines 735-782), token_buffer has shape (1, 2) (or (1, 4)), so token_buffer[max_prompt_length:] only contains a few tokens and find_first_eos_idx returns immediately, resulting in truncated outputs. Running the same logic with batch_size=2 produces a normal-length buffer and completion.
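
To make the control flow above concrete, here is a minimal, self-contained sketch of the allocate-then-fill pattern the bullets describe. It is not the tunix implementation; the constants, the while_loop condition, and the fake "model" (previous token + 1) are all illustrative.

```python
import jax
import jax.numpy as jnp

batch_size = 1
num_input_tokens = 4
total_sampling_steps = 16
pad_id = 0

# Prompt tokens for a single-item batch, shape (1, num_input_tokens).
prompt = jnp.arange(1, num_input_tokens + 1, dtype=jnp.int32)[None, :]

# The buffer is allocated once with a fixed shape and written in place.
token_buffer = jnp.full((batch_size, total_sampling_steps), pad_id, jnp.int32)
token_buffer = token_buffer.at[:, :num_input_tokens].set(prompt)

def cond_fn(state):
    step, _ = state
    return step < total_sampling_steps - 1

def body_fn(state):
    step, buf = state
    next_token = buf[:, step] + 1          # stand-in for the model's next token
    buf = buf.at[:, step + 1].set(next_token)
    return step + 1, buf

_, token_buffer = jax.lax.while_loop(
    cond_fn, body_fn, (num_input_tokens - 1, token_buffer)
)
print(token_buffer.shape)  # (1, 16): the shape must not shrink for batch_size=1
```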

I added debug host callbacks around _sample_step to log token_buffer.shape before and after each iteration, and the logs confirm that the .at[:, decoding_step + 1].set(...) update silently stops writing when there is a single example. This makes the sampler unusable when TRAIN_MICRO_BATCH_SIZE=1, even though it works fine for larger batches.
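
For reference, one way to wire up that kind of host-side logging is jax.debug.callback. The snippet below uses a toy step function rather than the real _sample_step (whose signature differs), so treat it as a sketch of the instrumentation, not the actual debug code:

```python
import jax
import jax.numpy as jnp

def log_buffer(step, buf):
    # Executes on the host, even from inside jit-compiled code.
    print(f"step={step}, token_buffer.shape={buf.shape}, row0={buf[0, :8]}")

@jax.jit
def toy_sample_step(step, token_buffer):
    jax.debug.callback(log_buffer, step, token_buffer)
    next_token = token_buffer[:, step] + 1          # stand-in for the model
    return token_buffer.at[:, step + 1].set(next_token)

buf = jnp.zeros((1, 16), dtype=jnp.int32)
buf = toy_sample_step(3, buf)
```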

Steps to reproduce

  1. Instantiate the sampler used in gemma2-grpo-original-demo.ipynb, e.g. the Tunix transformer + tokenizer + cache config.
  2. Call sampler(input_strings=[prompt], ...) with max_generation_steps large enough to reach <end_of_turn>, using greedy sampling (temperature=0.0001, top_k=1).
  3. Inspect result.tokens[0] / result.text[0] – the output is only a few tokens long and often doesn’t end with <end_of_turn>.
  4. Repeat step 2 with input_strings=[prompt, prompt] (batch_size=2) and observe a normal-length completion; a condensed sketch of steps 2–4 follows this list.
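
Condensed, the comparison in steps 2–4 looks roughly like this. Sampler construction from step 1 is elided, max_generation_steps=768 is just an illustrative value, and the keyword names are taken from the notebook/issue text, so treat it as a sketch rather than a verified API reference:

```python
prompt = "..."  # any prompt that normally ends with <end_of_turn>

kwargs = dict(max_generation_steps=768, temperature=0.0001, top_k=1)

single = sampler(input_strings=[prompt], **kwargs)          # batch_size=1
pair = sampler(input_strings=[prompt, prompt], **kwargs)    # batch_size=2

print(len(single.tokens[0]), repr(single.text[0][-30:]))    # truncated
print(len(pair.tokens[0]), repr(pair.text[0][-30:]))        # full completion
```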

Expected behavior

token_buffer should keep its (1, total_sampling_steps) shape and _sample() should keep writing tokens until EOS, just as it does when batch_size=2.

Actual behavior

token_buffer stops at ~2–4 slots, so the extractor only returns the first few tokens and the sample is truncated. The failure happens inside _sample() / _sample_step: the loop never populates the buffer past those first slots for single-item batches.

Proposed fix

Investigate how _sample() / _sample_step behaves under JAX when batch_size=1. The .at update may be optimized out or shape-inferred incorrectly. Ensuring that the token_buffer update actually runs (and that token_buffer keeps shape (batch_size, total_sampling_steps)) should fix the bug. If necessary, add a regression test for batch_size=1; a sketch of such a test follows.
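
The regression test could be as simple as comparing a single-item batch against the same prompt duplicated, under greedy sampling. Here build_test_sampler() is a hypothetical helper standing in for however the existing sampler tests construct a small Sampler:

```python
def test_single_item_batch_matches_duplicated_batch():
    sampler = build_test_sampler()  # hypothetical: small model + tokenizer + cache
    kwargs = dict(max_generation_steps=64, temperature=0.0001, top_k=1)

    single = sampler(input_strings=["Tell me a story."], **kwargs)
    pair = sampler(input_strings=["Tell me a story."] * 2, **kwargs)

    # Greedy sampling makes the duplicated prompt deterministic, so the
    # single-item batch must not come back shorter than its batch-of-two twin.
    assert len(single.tokens[0]) == len(pair.tokens[0])
    assert single.text[0] == pair.text[0]
```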
