Commit 9c60856

fused moe layer, edit grok
1 parent 15d82d3 commit 9c60856

File tree

13 files changed: +2646 −130 lines

python/sgl_jax/srt/configs/model_config.py

Lines changed: 19 additions & 0 deletions

@@ -29,6 +29,14 @@ class ModelImpl(str, Enum):
     TRANSFORMERS = "transformers"


+class MoEBackend(str, Enum):
+    """Backend for Mixture of Experts computation."""
+
+    EPMOE = "epmoe"  # Native Expert Parallel MoE (default)
+    FUSED = "fused"  # Fused kernel (TPU-optimized)
+    AUTO = "auto"  # Automatically select based on ep_size
+
+
 class ModelConfig:
     def __init__(
         self,
@@ -44,6 +52,7 @@ def __init__(
         model_impl: str | ModelImpl = ModelImpl.AUTO,
         quantization: str | None = None,
         model_layer_nums: int | None = None,
+        moe_backend: str | MoEBackend = MoEBackend.AUTO,
     ) -> None:

         self.model_path = model_path
@@ -53,6 +62,15 @@ def __init__(
         # if ep_size > 1, use ep moe, else use fused moe
         # TODO: support ep moe with ETP
         self.ep_size = 1
+
+        # Process MoE backend selection
+        self.moe_backend = MoEBackend(moe_backend) if isinstance(moe_backend, str) else moe_backend
+
+        # Auto-select backend based on ep_size
+        if self.moe_backend == MoEBackend.AUTO:
+            # If ep_size > 1, use EPMoE (expert parallelism across devices);
+            # otherwise use the fused kernel (single-device TPU optimization).
+            self.moe_backend = MoEBackend.EPMOE if self.ep_size > 1 else MoEBackend.FUSED
         # Parse args
         self.maybe_pull_model_tokenizer_from_remote()
         self.model_override_args = json.loads(model_override_args)
@@ -176,6 +194,7 @@ def from_server_args(
             quantization=server_args.quantization,
             model_impl=server_args.model_impl,
             model_layer_nums=server_args.model_layer_nums,
+            moe_backend=server_args.moe_backend,
             **kwargs,
         )
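
As a rough illustration of the selection rule this diff adds, the sketch below reproduces it in isolation. It is not part of the commit: the MoEBackend enum mirrors the one added above, select_moe_backend is a hypothetical helper, and the ep_size parameter stands in for the value that ModelConfig currently hardcodes to 1.

    # Minimal sketch (assumption: names other than MoEBackend are illustrative only).
    from enum import Enum


    class MoEBackend(str, Enum):
        EPMOE = "epmoe"  # native expert-parallel MoE
        FUSED = "fused"  # fused kernel (TPU-optimized)
        AUTO = "auto"    # resolve based on ep_size


    def select_moe_backend(moe_backend: str | MoEBackend, ep_size: int) -> MoEBackend:
        # Accept either the string form ("auto") or an enum member.
        backend = MoEBackend(moe_backend) if isinstance(moe_backend, str) else moe_backend
        if backend == MoEBackend.AUTO:
            # ep_size > 1: experts are sharded across devices, so use EPMoE;
            # otherwise fall back to the single-device fused kernel.
            backend = MoEBackend.EPMOE if ep_size > 1 else MoEBackend.FUSED
        return backend


    assert select_moe_backend("auto", ep_size=1) == MoEBackend.FUSED
    assert select_moe_backend("auto", ep_size=4) == MoEBackend.EPMOE
    assert select_moe_backend("epmoe", ep_size=1) == MoEBackend.EPMOE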

python/sgl_jax/srt/kernels/fused_moe/__init__.py

Whitespace-only changes.

python/sgl_jax/srt/kernels/fused_moe/v1/__init__.py

Whitespace-only changes.
