
Conversation

@IlyasMoutawwakil (Member) commented Dec 8, 2025

What does this PR do?

I have started experimenting with pure PyTorch MoE implementations, following the HF exporters PR, while trying to find a traceable/exportable variant for ONNX/OpenVINO.

In this PR I mirror the attn_implementation API with a similar moe_implementation API and add two new implementations (a rough sketch of the batched_mm compute pattern follows right after this list):

  • batched_mm (the exportable one), which uses torch.bmm; it is the fastest at batch size 1 but also uses a lot of memory
  • grouped_mm (the PyTorch custom-kernel one), inspired by torchtitan's MoE implementation (it uses torch._grouped_mm); it is fast and uses about as much memory as eager (which honestly surprised me)
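
To make the batched_mm idea concrete, here is a minimal sketch of the dense bmm compute pattern as described above. It is illustrative only, not the exact code in this PR, and the tensor names (gate_proj, up_proj, down_proj, routing_weights) are assumptions for a SwiGLU-style expert. Every expert processes every token and non-selected experts are zeroed by the routing weights, which keeps the graph free of data-dependent control flow (hence exportable) but is also why it uses a lot of memory:

```python
import torch
import torch.nn.functional as F

def moe_batched_mm(hidden_states, gate_proj, up_proj, down_proj, routing_weights):
    # hidden_states:      (num_tokens, hidden)
    # gate_proj/up_proj:  (num_experts, hidden, intermediate)
    # down_proj:          (num_experts, intermediate, hidden)
    # routing_weights:    (num_tokens, num_experts), zero for non-selected experts
    num_experts = gate_proj.shape[0]
    x = hidden_states.unsqueeze(0).expand(num_experts, -1, -1)   # (E, T, H)
    gate = torch.bmm(x, gate_proj)                               # (E, T, I)
    up = torch.bmm(x, up_proj)                                   # (E, T, I)
    expert_out = torch.bmm(F.silu(gate) * up, down_proj)         # (E, T, H)
    # Scale each expert's output by its routing weight, then sum over experts.
    return (expert_out * routing_weights.t().unsqueeze(-1)).sum(dim=0)  # (T, H)
```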

Benchmark

An initial benchmark shows promising results (on an A100). I know that torch._grouped_mm uses bfloat16 or something similar under the hood, so these might not be apples to apples (I'm still looking for more references on this function and how to use it "equivalently").

MoE Implementations Benchmark

Benchmark script: bench.py
It uses qwen2_moe ("Qwen/Qwen1.5-MoE-A2.7B", float32); latency and memory are measured for the forward pass (prefill).
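
For reference, a minimal sketch of how such a prefill measurement could look. This is not the actual bench.py, and the moe_implementation kwarg name is assumed from the API described above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen1.5-MoE-A2.7B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# moe_implementation is the new kwarg introduced by this PR (name assumed here)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float32, moe_implementation="grouped_mm"
).to("cuda").eval()

inputs = tokenizer(["Hello world"], return_tensors="pt").to("cuda")
torch.cuda.reset_peak_memory_stats()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

with torch.no_grad():
    model(**inputs)  # warmup
    start.record()
    model(**inputs)  # timed forward pass / prefill
    end.record()
torch.cuda.synchronize()

print(f"latency:  {start.elapsed_time(end):.2f} ms")
print(f"peak mem: {torch.cuda.max_memory_allocated() / 2**20:.2f} MB")
```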

Batch Size: 1

| Implementation | Mean Latency (ms) | Median Latency (ms) | P90 Latency (ms) | Peak Mem (MB) |
|---|---|---|---|---|
| eager | 182.65 | 178.79 | 206.35 | 54628.51 |
| batched_mm | 48.22 | 48.21 | 48.31 | 55421.70 |
| grouped_mm | 55.55 | 55.55 | 55.98 | 54635.47 |

Batch Size: 4

| Implementation | Mean Latency (ms) | Median Latency (ms) | P90 Latency (ms) | Peak Mem (MB) |
|---|---|---|---|---|
| eager | 363.03 | 362.73 | 366.36 | 54669.29 |
| batched_mm | 181.18 | 181.12 | 181.63 | 58899.68 |
| grouped_mm | 85.78 | 85.72 | 86.18 | 54706.38 |

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@IlyasMoutawwakil changed the title from "BMM MoE implementation" to "batched and grouped MoE implementations" on Dec 8, 2025
```python
# Important: torch._grouped_mm requires mat_a and mat_b to have strides that are multiples of 16
# still can't find a reference for this constraint but I had models failing if not respected
mat_a_up = current_states_g
mat_b_up = gate_up_proj.transpose(-2, -1)
```
Collaborator

If you transpose, then maybe this one can be done in the online weight conversion.

Comment on lines 179 to 182
```python
if mat_a_up.stride(1) % 16 != 0:
    mat_a_up = _make_stride_multiple_of(mat_a_up, 1, 16)
if mat_b_up.stride(1) % 16 != 0:
    mat_b_up = _make_stride_multiple_of(mat_b_up, 1, 16)
```
Collaborator

Bit weird to do this on each forward; are we not wasting time there?
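
For what it's worth, a rough sketch of doing that work once up front instead of on every forward (illustrative only, not a proposal for the exact converter API):

```python
import torch

def prepare_expert_weights(gate_up_proj: torch.Tensor, down_proj: torch.Tensor):
    # Done once when the weights are loaded/converted: transpose to the layout
    # the grouped mm expects and make the result contiguous so the strides are
    # fixed here; any padding needed for the multiple-of-16 constraint could
    # also be applied at this point instead of on every forward.
    mat_b_up = gate_up_proj.transpose(-2, -1).contiguous()
    mat_b_down = down_proj.transpose(-2, -1).contiguous()
    return mat_b_up, mat_b_down
```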


```python
# --- Down projection per expert (grouped_mm) ---
mat_a_down = hidden_after_activation
mat_b_down = down_proj.transpose(-2, -1)
```
Collaborator

Same here; the way we want v5 is to have "perfect" weights with the weight converter -> this can be done in the weight converter.

```python
# Accumulate results back to the final_hidden_states using original token indices
final_hidden_states.index_add_(0, token_idx, out_per_sample.to(final_hidden_states.dtype))

return final_hidden_states
```
Collaborator

I like this one more because it does not have the additional transpose.
How does it compare across different num_experts setups?
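
One way to get numbers for that question: a small synthetic sweep over num_experts (illustrative shapes, reusing the moe_batched_mm sketch from the PR description above; the eager loop here routes densely for simplicity):

```python
import time
import torch
import torch.nn.functional as F

def eager_moe(x, gate_proj, up_proj, down_proj, routing_weights):
    # Naive per-expert loop (dense routing, for comparison only)
    out = torch.zeros_like(x)
    for e in range(gate_proj.shape[0]):
        h = F.silu(x @ gate_proj[e]) * (x @ up_proj[e])
        out += (h @ down_proj[e]) * routing_weights[:, e : e + 1]
    return out

for num_experts in (8, 16, 32, 64):
    tokens, hidden, intermediate = 1024, 2048, 1408
    x = torch.randn(tokens, hidden, device="cuda")
    gate = torch.randn(num_experts, hidden, intermediate, device="cuda")
    up = torch.randn(num_experts, hidden, intermediate, device="cuda")
    down = torch.randn(num_experts, intermediate, hidden, device="cuda")
    router = torch.softmax(torch.randn(tokens, num_experts, device="cuda"), dim=-1)
    for fn in (eager_moe, moe_batched_mm):  # moe_batched_mm from the sketch above
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        fn(x, gate, up, down, router)
        torch.cuda.synchronize()
        print(f"E={num_experts:3d} {fn.__name__:15s} {(time.perf_counter() - t0) * 1e3:.2f} ms")
```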

Collaborator

Missing license.

If we go down that road, which I like TBH, we should also add kernels support + try at least FP8 to see how this would work.

Because the other solution is to use the use_hf_hub_kernel decorator as well, but that looks more cumbersome. So, following FA2, we want to support kernels from the hub here as well + quantization.

This does look nice; for the bench, can you add torch.compile cases as well please?

Also, we need to make sure this works with TP / EP.
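
For the torch.compile case, the benchmark sketch above could be extended roughly like this (model and inputs as defined there; illustrative only):

```python
import torch

compiled = torch.compile(model)  # `model` / `inputs` from the benchmark sketch above
with torch.no_grad():
    compiled(**inputs)  # first call triggers compilation; keep it out of the timing
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    compiled(**inputs)
    end.record()
torch.cuda.synchronize()
print(f"compiled latency: {start.elapsed_time(end):.2f} ms")
```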

Comment on lines 101 to 122
```python
def _pad_dim_end(t: torch.Tensor, dim: int, pad_elems: int):
    if pad_elems == 0:
        return t
    new_shape = list(t.shape)
    new_shape[dim] += pad_elems
    padded = t.new_zeros(*new_shape)
    idx = [slice(None)] * t.dim()
    idx[dim] = slice(0, t.shape[dim])
    padded[tuple(idx)] = t
    return padded


def _make_stride_multiple_of(t: torch.Tensor, dim: int, multiple: int):
    stride = t.stride(dim)
    if stride % multiple == 0:
        return t
    elem_size = t.element_size()
    align_elems = max(1, multiple // elem_size)
    k = t.shape[dim]
    pad = (-k) % align_elems
    return _pad_dim_end(t, dim, pad)
```

Collaborator

Any of these should in some way update the weights -> caching the changes we make, as long as they are small, might require us to have access to the state.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: dbrx, deepseek_v2, deepseek_v3, dots1, ernie4_5_moe, flex_olmo, glm4_moe, glm4v_moe, hunyuan_v1_moe, jamba, lfm2_moe, minimax, mixtral, olmoe, phimoe, qwen2_moe

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42697&sha=ea2f39
