batched and grouped MoE implementations #42697
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
# Important: torch._grouped_mm requires mat_a and mat_b to have strides that are multiples of 16
# still can't find a reference for this constraint but I had models failing if not respected
mat_a_up = current_states_g
mat_b_up = gate_up_proj.transpose(-2, -1)
if you transpose then maybe this one can be done in the online weight conversion
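A rough sketch of what that could look like as a one-time, load-time conversion (the function name and placement are illustrative, not the actual transformers weight-converter API):

import torch

def pretranspose_expert_weights(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    # Hypothetical one-time conversion: store the expert weights already transposed
    # (and contiguous) so the grouped_mm path does not need to transpose on every forward.
    for key in list(state_dict.keys()):
        if key.endswith(("gate_up_proj", "down_proj")):
            state_dict[key] = state_dict[key].transpose(-2, -1).contiguous()
    return state_dict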
src/transformers/integrations/moe.py
Outdated
if mat_a_up.stride(1) % 16 != 0:
    mat_a_up = _make_stride_multiple_of(mat_a_up, 1, 16)
if mat_b_up.stride(1) % 16 != 0:
    mat_b_up = _make_stride_multiple_of(mat_b_up, 1, 16)
Bit weird to do this on each forward; are we not wasting time there?
# --- Down projection per expert (grouped_mm) ---
mat_a_down = hidden_after_activation
mat_b_down = down_proj.transpose(-2, -1)
same here, the way we want v5 is to have "perfect" weights with the weight converter -> this can be done in the weight converter
# Accumulate results back to the final_hidden_states using original token indices
final_hidden_states.index_add_(0, token_idx, out_per_sample.to(final_hidden_states.dtype))

return final_hidden_states
I like this one more because it does not have the additional transpose.
How does it compare for different num experts setups?
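For reference, a toy illustration of the index_add_ scatter-back pattern quoted above (only the accumulation step, not the PR's grouped kernel path; all shapes and values are made up):

import torch

# out_per_sample holds the expert outputs for the routed (token, expert) pairs,
# token_idx maps each of those rows back to its position in the original sequence.
final_hidden_states = torch.zeros(5, 4)
token_idx = torch.tensor([3, 0, 3, 1])        # token 3 was routed to two experts
out_per_sample = torch.ones(4, 4)

final_hidden_states.index_add_(0, token_idx, out_per_sample)
# rows 0 and 1 each receive one expert output, row 3 accumulates two of them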
missing licence.
If we go down that road, which I like TBH, we should also add kernels support + try at least FP8 to see how this would work.
Because the other solution is to use the use_hf_hub_kernel decorator as well, but that looks more cumbersome. So, following FA2, we want to support kernels from the hub in this as well + quantization.
This does look nice, for the bench can you add compile cases as well please?
Also we need to make sure this works with TP / EP
src/transformers/integrations/moe.py
Outdated
def _pad_dim_end(t: torch.Tensor, dim: int, pad_elems: int):
    if pad_elems == 0:
        return t
    new_shape = list(t.shape)
    new_shape[dim] += pad_elems
    padded = t.new_zeros(*new_shape)
    idx = [slice(None)] * t.dim()
    idx[dim] = slice(0, t.shape[dim])
    padded[tuple(idx)] = t
    return padded


def _make_stride_multiple_of(t: torch.Tensor, dim: int, multiple: int):
    stride = t.stride(dim)
    if stride % multiple == 0:
        return t
    elem_size = t.element_size()
    align_elems = max(1, multiple // elem_size)
    k = t.shape[dim]
    pad = (-k) % align_elems
    return _pad_dim_end(t, dim, pad)
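A quick sanity check for the padding helper above (shapes are illustrative):

import torch

x = torch.randn(4, 10)
y = _pad_dim_end(x, dim=1, pad_elems=6)

assert y.shape == (4, 16)
assert torch.equal(y[:, :10], x)   # original values preserved
assert torch.all(y[:, 10:] == 0)   # tail is zero padding
assert y.stride(0) == 16           # row stride of the contiguous result is now a multiple of 16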
any of these should in some way update the weights -> caching the changes we do, as long as they are small, might need us to have access to the state
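One hedged sketch of what that caching could look like, assuming we are allowed to stash the adjusted operand on the module itself (class and attribute names are made up for illustration):

import torch
from torch import nn

class GroupedMoEProjection(nn.Module):
    # Illustrative module that builds the transposed/contiguous operand once and reuses it.

    def __init__(self, gate_up_proj: torch.Tensor):
        super().__init__()
        self.gate_up_proj = nn.Parameter(gate_up_proj)
        self._cached_mat_b: torch.Tensor | None = None

    def _mat_b(self) -> torch.Tensor:
        # Lazily cache the transposed weight; the cache would need to be invalidated
        # whenever the parameter changes (e.g. during training or weight reloading).
        if self._cached_mat_b is None:
            self._cached_mat_b = self.gate_up_proj.transpose(-2, -1).contiguous()
        return self._cached_mat_b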
[For maintainers] Suggested jobs to run (before merge) run-slow: dbrx, deepseek_v2, deepseek_v3, dots1, ernie4_5_moe, flex_olmo, glm4_moe, glm4v_moe, hunyuan_v1_moe, jamba, lfm2_moe, minimax, mixtral, olmoe, phimoe, qwen2_moe
View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42697&sha=ea2f39
What does this PR do?
I have started experimenting with pure PyTorch MoE implementations following the HF exporters PR, while trying to find a traceable/exportable variant for ONNX/OpenVINO.
In this PR I copy the attn_implementation API into a similar moe_implementation API (a usage sketch follows below), and add two new implementations:
- batched_mm (the exportable one), which uses torch.bmm; it is the fastest at batch size 1 but also uses a lot of memory
- grouped_mm (the PyTorch custom-kernel one), inspired by torchtitan's MoE implementation (using torch._grouped_mm), which is fast and uses just as much memory as eager (I'm kinda surprised by that, tbh)
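A minimal usage sketch of the proposed argument, assuming it mirrors attn_implementation in from_pretrained (the exact wiring in this PR may differ):

from transformers import AutoModelForCausalLM

# Hypothetical usage based on the PR's description of a `moe_implementation` API
# analogous to `attn_implementation`.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-MoE-A2.7B",
    moe_implementation="grouped_mm",  # or "batched_mm" / "eager"
)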
Benchmark

An initial benchmark shows promising results (on an A100). I know that torch._grouped_mm uses bfloat16 or something under the hood, so these might not be apples to apples (I'm still looking for more references on this function and how to use it "equivalently").

MoE Implementations Benchmark
Benchmark script: bench.py
It uses qwen2_moe ("Qwen/Qwen1.5-MoE-A2.7B", float32) where latency and memory are for the forward pass / prefill
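For context, a bare-bones version of what such a forward-pass benchmark could look like, including a torch.compile case as requested above (this is only a sketch, not the PR's bench.py; the moe_implementation kwarg is the hypothetical one proposed here):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def bench_forward(moe_impl: str, batch_size: int, compile_model: bool = False):
    # `moe_implementation` is the (hypothetical) kwarg introduced by this PR.
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-MoE-A2.7B", torch_dtype=torch.float32, moe_implementation=moe_impl
    ).to("cuda").eval()
    if compile_model:
        model = torch.compile(model)

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")
    inputs = tokenizer(["Hello world"] * batch_size, return_tensors="pt").to("cuda")

    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        for _ in range(3):                      # warmup (also triggers compilation)
            model(**inputs)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(10):
            model(**inputs)
        torch.cuda.synchronize()

    latency_ms = (time.perf_counter() - start) / 10 * 1000
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    return latency_ms, peak_gb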
Batch Size: 1
Batch Size: 4
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.