@@ -29,6 +29,14 @@ class ModelImpl(str, Enum):
     TRANSFORMERS = "transformers"
 
 
+class MoEBackend(str, Enum):
+    """Backend for Mixture of Experts computation."""
+
+    EPMOE = "epmoe"  # Native Expert Parallel MoE (default)
+    FUSED = "fused"  # Fused kernel (TPU-optimized)
+    AUTO = "auto"  # Automatically select based on ep_size
+
+
 class ModelConfig:
     def __init__(
         self,
@@ -44,6 +52,7 @@ def __init__(
         model_impl: str | ModelImpl = ModelImpl.AUTO,
         quantization: str | None = None,
         model_layer_nums: int | None = None,
+        moe_backend: str | MoEBackend = MoEBackend.AUTO,
     ) -> None:
 
         self.model_path = model_path
@@ -53,6 +62,15 @@ def __init__(
         # if ep_size > 1, use ep moe, else use fused moe
         # TODO: support ep moe with ETP
         self.ep_size = 1
+
+        # Process MoE backend selection
+        self.moe_backend = MoEBackend(moe_backend) if isinstance(moe_backend, str) else moe_backend
+
+        # Auto-select backend based on ep_size
+        if self.moe_backend == MoEBackend.AUTO:
+            # If ep_size > 1, use EPMoE (expert parallelism across devices)
+            # Otherwise use Fused kernel (single-device TPU optimization)
+            self.moe_backend = MoEBackend.EPMOE if self.ep_size > 1 else MoEBackend.FUSED
         # Parse args
         self.maybe_pull_model_tokenizer_from_remote()
         self.model_override_args = json.loads(model_override_args)
@@ -176,6 +194,7 @@ def from_server_args(
             quantization=server_args.quantization,
             model_impl=server_args.model_impl,
             model_layer_nums=server_args.model_layer_nums,
+            moe_backend=server_args.moe_backend,
             **kwargs,
         )
 
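As a sanity check of the AUTO-resolution logic introduced above, here is a minimal standalone sketch. It re-declares the enum rather than importing it from the repo, and the `resolve_moe_backend` helper is purely illustrative (not part of the actual patch); it only mirrors the selection performed in `ModelConfig.__init__`:

```python
from enum import Enum


class MoEBackend(str, Enum):
    """Illustrative copy of the enum added in this commit."""

    EPMOE = "epmoe"
    FUSED = "fused"
    AUTO = "auto"


def resolve_moe_backend(moe_backend: str | MoEBackend, ep_size: int) -> MoEBackend:
    """Hypothetical helper reproducing the constructor's selection logic."""
    backend = MoEBackend(moe_backend) if isinstance(moe_backend, str) else moe_backend
    if backend == MoEBackend.AUTO:
        # ep_size > 1 -> expert parallelism across devices; otherwise fused kernel
        backend = MoEBackend.EPMOE if ep_size > 1 else MoEBackend.FUSED
    return backend


assert resolve_moe_backend("auto", ep_size=1) == MoEBackend.FUSED
assert resolve_moe_backend("auto", ep_size=4) == MoEBackend.EPMOE
assert resolve_moe_backend(MoEBackend.EPMOE, ep_size=1) == MoEBackend.EPMOE
```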