diff --git a/.gitignore b/.gitignore index 5cb5b516a..aa8607992 100644 --- a/.gitignore +++ b/.gitignore @@ -241,3 +241,4 @@ CLAUDE.md #gemini code .gemini-clipboard +GEMINI.md diff --git a/benchmark/fused_moe/README.md b/benchmark/fused_moe/README.md new file mode 100644 index 000000000..e3fae9d62 --- /dev/null +++ b/benchmark/fused_moe/README.md @@ -0,0 +1,428 @@ +# Fused MoE vs EP MoE Benchmark + +Comprehensive layer-level benchmark comparing **FusedEPMoE** (Pallas TPU kernel) vs **EPMoE** (GMM kernel) implementations. + +## Features + +- ✅ **Layer-level testing** with synthetic weights +- ✅ **Controlled token distribution** scenarios (random, balanced, imbalanced) +- ✅ **Load imbalance metrics** (max expert load / average load) +- ✅ **Distributed configurations** (EP, TP) with multi-node support +- ✅ **JAX profiling** support via `jax.profiler` +- ✅ **HuggingFace config loading** or manual configuration +- ✅ **Multiple output formats** (CSV for plotting, Markdown for viewing) + +## Quick Start + +### 1. Simple Test (Manual Config) + +```bash +python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --num-experts 8 \ + --num-experts-per-tok 2 \ + --hidden-size 1024 \ + --intermediate-size 4096 \ + --num-tokens 512 \ + --scenarios random +``` + +### 2. Using HuggingFace Model Config + +```bash +python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path Qwen/Qwen2.5-MoE-A2.7B \ + --ep-size 8 \ + --num-tokens 1024 2048 4096 \ + --scenarios random balanced imbalanced +``` + +### 3. With Profiling + +```bash +python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path Qwen/Qwen2.5-MoE-A2.7B \ + --ep-size 8 \ + --num-tokens 1024 2048 \ + --scenarios random balanced imbalanced \ + --profile \ + --profile-dir ./profiles/qwen_benchmark +``` + +### 4. 
Multi-Node Setup + +**Node 0:** +```bash +python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path Qwen/Qwen2.5-MoE-A2.7B \ + --dist-init-addr 10.0.0.1:12345 \ + --nnodes 2 \ + --node-rank 0 \ + --ep-size 16 +``` + +**Node 1:** +```bash +python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path Qwen/Qwen2.5-MoE-A2.7B \ + --dist-init-addr 10.0.0.1:12345 \ + --nnodes 2 \ + --node-rank 1 \ + --ep-size 16 +``` + +## Command-Line Arguments + +### Model Configuration + +Choose **one** of: + +- `--model-path PATH`: Load config from HuggingFace model (downloads or loads from local) +- `--manual-config`: Use manual configuration (requires additional args below) + +**Manual configuration options** (required if `--manual-config` is used): +- `--num-experts INT`: Number of experts +- `--num-experts-per-tok INT`: Top-k value +- `--hidden-size INT`: Hidden dimension +- `--intermediate-size INT`: Intermediate dimension +- `--activation {silu,gelu,swigluoai}`: Activation function (default: silu) + +### Distributed Configuration + +- `--ep-size INT`: Expert parallel size (default: 1) +- `--tp-size INT`: Total number of devices to use (default: 1) + - The actual tensor parallel size is computed as `tp_size // ep_size` in MoE layers +- `--dist-init-addr ADDR`: Distributed init address (e.g., `10.0.0.1:12345`) +- `--nnodes INT`: Number of nodes (default: 1) +- `--node-rank INT`: Current node rank (default: 0) + +### Benchmark Parameters + +- `--num-tokens INT [INT ...]`: List of token counts to test (default: 512 1024 2048) +- `--scenarios {random,balanced,imbalanced} [...]`: Scenarios to test (default: all) +- `--imbalance-factor FLOAT`: Target imbalance for "imbalanced" scenario (default: 3.0) + - **Definition**: `max_expert_load / avg_expert_load` + - Balanced scenario always targets ~1.0 (perfect balance) + - Imbalanced scenario uses this factor (e.g., 3.0 = busiest expert gets 3x average) +- `--warmup-iters INT`: Warmup iterations (default: 1, only need one for JAX JIT) +- `--benchmark-iters INT`: Benchmark iterations (default: 10) + +### Profiling + +- `--profile`: Enable JAX profiler +- `--profile-dir PATH`: Profile output directory (default: ./profiles) + +### Output + +- `--output-format {csv,markdown,both}`: Output format (default: both) +- `--output-file PATH`: Output file base path (default: ./benchmark_results) +- `--verbose`: Enable verbose logging + +## Scenarios + +### 1. Random + +Uniform random router logits from N(0, 1). Results in natural imbalance ~1.2-1.5x. + +**Use case**: Realistic scenario simulating natural token distribution. + +### 2. Balanced + +Engineered logits using round-robin assignment to ensure equal expert distribution. Target imbalance: ~1.0x (perfect balance). + +**Use case**: Best-case scenario for MoE performance. + +### 3. Imbalanced + +Exponential distribution favoring first few experts. Controlled by `--imbalance-factor` (default: 3.0). + +**Use case**: Worst-case scenario to test robustness under load imbalance. 
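+
+To make the scenario definitions above concrete, here is a minimal standalone sketch (NumPy only) of how an imbalance factor can be estimated for a batch of router logits under greedy top-k routing. The helper name `imbalance_factor` is illustrative and is not part of the benchmark code:
+
+```python
+# Standalone sketch: estimate max_expert_load / avg_expert_load for a set of
+# router logits. `imbalance_factor` is an illustrative helper, not benchmark code.
+import numpy as np
+
+def imbalance_factor(router_logits: np.ndarray, top_k: int) -> float:
+    """Return max expert load divided by average expert load for top-k routing."""
+    num_experts = router_logits.shape[-1]
+    # Indices of the top-k experts selected for each token.
+    topk_ids = np.argsort(router_logits, axis=-1)[:, -top_k:]
+    loads = np.bincount(topk_ids.ravel(), minlength=num_experts)
+    return float(loads.max() / loads.mean())
+
+rng = np.random.default_rng(0)
+logits = rng.normal(size=(1024, 60))  # "random" scenario: i.i.d. N(0, 1) logits
+print(f"random-scenario imbalance ~ {imbalance_factor(logits, top_k=8):.2f}x")
+```
+
+For the balanced scenario the same estimate approaches 1.0x, while the imbalanced scenario is tuned toward the value passed via `--imbalance-factor`.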
+
+## Output Formats
+
+### CSV Format
+
+```csv
+implementation,scenario,num_tokens,ep_size,tp_size,num_experts,num_experts_per_tok,
+latency_mean_ms,latency_std_ms,latency_p50_ms,latency_p95_ms,latency_p99_ms,
+max_load,min_load,avg_load,max_imbalance,throughput_tok_per_sec
+fused,random,1024,8,1,60,8,2.3456,0.1234,2.3000,2.5000,2.6000,150,130,140.5,1.07,436543.21
+epmoe,random,1024,8,1,60,8,3.1234,0.2345,3.1000,3.4000,3.5000,150,130,140.5,1.07,327891.23
+```
+
+**Columns:**
+- `implementation`: "fused" or "epmoe"
+- `scenario`: "random", "balanced", or "imbalanced"
+- `num_tokens`: Number of tokens in the test
+- `ep_size`: Expert parallel size
+- `tp_size`: Total number of devices (actual tensor parallel = tp_size // ep_size)
+- `num_experts`, `num_experts_per_tok`: MoE configuration
+- `latency_*_ms`: Latency statistics in milliseconds
+- `max_load`, `min_load`, `avg_load`: Expert load distribution
+- `max_imbalance`: Maximum imbalance ratio (max_load / avg_load)
+- `throughput_tok_per_sec`: Throughput in tokens per second
+
+### Markdown Format
+
+```markdown
+# MoE Benchmark Results
+**Configuration:** 60 experts, top-8, EP=8, TP=1
+
+## Scenario: balanced, Tokens: 1024
+
+| Metric | Fused MoE | EP MoE | Speedup |
+|--------|-----------|--------|---------|
+| Mean Latency (ms) | 2.3456 | 3.1234 | 1.33x |
+| P95 Latency (ms) | 2.5678 | 3.4567 | - |
+| Throughput (tok/s) | 436.5 | 327.8 | - |
+| Max Imbalance | 1.05x | 1.05x | - |
+
+## Scenario: imbalanced, Tokens: 1024
+
+| Metric | Fused MoE | EP MoE | Speedup |
+|--------|-----------|--------|---------|
+| Mean Latency (ms) | 2.8901 | 3.6789 | 1.27x |
+| Max Imbalance | 3.12x | 3.12x | - |
+```
+
+## Imbalance Metrics
+
+The benchmark reports the **imbalance factor**, defined as:
+
+```
+imbalance_factor = max_expert_load / avg_expert_load
+```
+
+**Examples:**
+- `1.0x`: Perfect balance (all experts receive equal tokens)
+- `1.5x`: Mild imbalance (busiest expert gets 50% more than average)
+- `3.0x`: High imbalance (busiest expert gets 3x more than average)
+
+**Reported metrics:**
+- `max_load`: Maximum tokens assigned to any single expert
+- `min_load`: Minimum tokens assigned to any single expert
+- `avg_load`: Average tokens per expert
+- `max_imbalance`: Imbalance factor (max_load / avg_load)
+
+## Profiling
+
+When `--profile` is enabled, JAX profiler traces are saved for each scenario/implementation combination.
+
+### View Traces
+
+Trace files can be loaded and visualized with:
+
+1. **Perfetto UI**: https://ui.perfetto.dev/ (any browser)
+2. **Chrome Tracing**: chrome://tracing (Chrome browser only)
+
+Open the trace file from `<profile-dir>/<scenario>_tokens<num_tokens>_<fused|epmoe>/plugins/profile/*/trace.json.gz`.
+
+If the browser cannot open the trace file because it is too large, reduce `--num-tokens` or `--benchmark-iters` to generate smaller traces.
+
+### View Traces with TensorBoard
+
+```bash
+tensorboard --logdir=<profile-dir>
+# Open the displayed URL in a browser
+```
+
+### View Traces with XProf
+
+[XProf](https://github.com/openxla/xprof) includes a suite of profiling tools for JAX, TensorFlow, and PyTorch/XLA.
+
+```bash
+# Install XProf (nightly version)
+pip install xprof-nightly
+
+# Without TensorBoard:
+xprof --logdir=<profile-dir> --port=6006
+
+# With TensorBoard:
+tensorboard --logdir=<profile-dir>
+```
+
+## Implementation Details
+
+### Weight Equivalence
+
+The benchmark ensures mathematical equivalence between FusedEPMoE and EPMoE:
+
+```python
+# EPMoE format
+wi_0: (num_experts, hidden_size, intermediate_size)  # gate projection
+wi_1: (num_experts, hidden_size, intermediate_size)  # up projection
+wo: (num_experts, intermediate_size, hidden_size)    # down projection
+
+# FusedEPMoE format (transposed!)
+w1: (num_experts, 2, intermediate_size, hidden_size)  # [gate, up] fused
+w2: (num_experts, intermediate_size, hidden_size)     # down projection
+
+# Transformation
+w1[:, 0, :, :] = wi_0.transpose(0, 2, 1)  # gate
+w1[:, 1, :, :] = wi_1.transpose(0, 2, 1)  # up
+w2 = wo
+```
+
+### Router Logits
+
+Both implementations receive the same router logits, but:
+- **FusedEPMoE**: Handles top-k selection internally
+- **EPMoE**: Requires an explicit `TopK` module call first
+
+## File Structure
+
+```
+benchmark/fused_moe/
+├── bench_fused_vs_epmoe.py   # Main benchmark script
+├── config_utils.py           # Configuration loading and validation
+├── synthetic_data.py         # Synthetic data generation
+├── benchmark_runner.py       # Core benchmark execution
+├── output_formatter.py       # CSV and Markdown formatting
+└── README.md                 # This file
+```
+
+## Troubleshooting
+
+### Error: "num_experts must be divisible by ep_size"
+
+Ensure `num_experts % ep_size == 0`. For example, if you have 60 experts, valid `ep_size` values are: 1, 2, 3, 4, 5, 6, 10, 12, 15, 20, 30, 60.
+
+### Error: "tp_size exceeds device count"
+
+`--tp-size` should equal the total number of devices you want to use. Check available devices with:
+
+```python
+import jax
+print(f"Available devices: {jax.device_count()}")
+```
+
+Example: For 4 devices with EP=4, use `--tp-size 4 --ep-size 4` (tp_actual will be 1).
+
+### Multi-node setup not working
+
+- Ensure `--dist-init-addr` is accessible from all nodes
+- Verify firewall rules allow communication on the specified port
+- Check that `--nnodes` and `--node-rank` are correct for each node
+
+## Benchmark Test Cases
+
+### Case 1: Qwen3-Coder-30B-A3B-Instruct (4x TPU v6e)
+
+**Configuration:**
+- Model: `Qwen/Qwen3-Coder-30B-A3B-Instruct`
+- Hardware: 4x TPU v6e chips
+- Token counts: 1024, 2048, 4096, 8192, 16384
+- Scenarios: random, balanced, imbalanced
+
+**Note**: 4 chips cannot support `ep_size=8`. Recommended comparison: `ep_size=4, tp_size=4` (full EP, tp_actual=1) vs `ep_size=1, tp_size=4` (full TP, tp_actual=4).
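+
+For reference, here is a minimal standalone sketch of the configuration arithmetic used above. The helper `check_parallelism` is illustrative, not part of the benchmark code, and it additionally assumes `tp_size` is divisible by `ep_size`, which the `tp_size // ep_size` rule implies:
+
+```python
+# Standalone sketch of the parallelism checks described above; `check_parallelism`
+# is an illustrative helper, not part of the benchmark code.
+def check_parallelism(num_experts: int, tp_size: int, ep_size: int) -> tuple[int, int]:
+    """Return (ep_size, tp_actual), raising if the configuration is invalid."""
+    if num_experts % ep_size != 0:
+        raise ValueError(f"num_experts ({num_experts}) must be divisible by ep_size ({ep_size})")
+    if tp_size % ep_size != 0:
+        raise ValueError(f"tp_size ({tp_size}) must be divisible by ep_size ({ep_size})")
+    return ep_size, tp_size // ep_size  # actual tensor parallel size
+
+# The two recommended 4-chip comparisons, reusing the 60-expert example from above:
+print(check_parallelism(60, tp_size=4, ep_size=4))  # (4, 1) -> full EP
+print(check_parallelism(60, tp_size=4, ep_size=1))  # (1, 4) -> full TP
+```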
+ +**Test commands:** + +```bash +# Test 1: 4 devices, EP=4, tp_actual=1 (Expert Parallelism) +python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path Qwen/Qwen3-Coder-30B-A3B-Instruct \ + --tp-size 4 \ + --ep-size 4 \ + --num-tokens 1024 2048 4096 8192 16384 \ + --scenarios random balanced imbalanced \ + --warmup-iters 1 \ + --benchmark-iters 10 \ + --output-file ./results/qwen3_ep4_tp1 + +# Test 2: 4 devices, EP=1, tp_actual=4 (Tensor Parallelism only) +python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path Qwen/Qwen3-Coder-30B-A3B-Instruct \ + --tp-size 4 \ + --ep-size 1 \ + --num-tokens 1024 2048 4096 8192 16384 \ + --scenarios random balanced imbalanced \ + --warmup-iters 1 \ + --benchmark-iters 10 \ + --output-file ./results/qwen3_ep1_tp4 +``` + +### Case 2: Grok2 (32 chips, 8 machines) + +**Configuration:** +- Model: Grok2 +- Hardware: 32 chips across 8 machines (4 chips per machine) +- Token counts: 1024, 2048, 4096, 8192, 16384 +- Scenarios: random, balanced, imbalanced + +**Test commands:** + +Run on each machine with different `--node-rank` (0-7): + +```bash +# Test 1: 32 devices, EP=8, tp_actual=4 (Expert Parallelism) +# Machine 0 (rank 0): +python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path /path/to/grok2 \ + --tp-size 32 \ + --ep-size 8 \ + --dist-init-addr :12345 \ + --nnodes 8 \ + --node-rank 0 \ + --num-tokens 1024 2048 4096 8192 16384 \ + --scenarios random balanced imbalanced \ + --warmup-iters 1 \ + --benchmark-iters 10 \ + --output-file ./results/grok2_ep8_tp4 + +# Machines 1-7: Same command but change --node-rank to 1, 2, ..., 7 + +# Test 2: 32 devices, EP=1, tp_actual=32 (Tensor Parallelism only) +# Machine 0 (rank 0): +python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path /path/to/grok2 \ + --tp-size 32 \ + --ep-size 1 \ + --dist-init-addr :12345 \ + --nnodes 8 \ + --node-rank 0 \ + --num-tokens 1024 2048 4096 8192 16384 \ + --scenarios random balanced imbalanced \ + --warmup-iters 1 \ + --benchmark-iters 10 \ + --output-file ./results/grok2_ep1_tp32 + +# Machines 1-7: Same command but change --node-rank to 1, 2, ..., 7 +``` + +**Multi-machine launcher script:** + +```bash +#!/bin/bash +# run_grok2_bench.sh +MASTER_IP="10.0.0.1" # Replace with actual master IP +NODE_RANK=${1:-0} + +python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path /path/to/grok2 \ + --tp-size 32 \ + --ep-size 8 \ + --dist-init-addr ${MASTER_IP}:12345 \ + --nnodes 8 \ + --node-rank $NODE_RANK \ + --num-tokens 1024 2048 4096 8192 16384 \ + --scenarios random balanced imbalanced \ + --warmup-iters 1 \ + --benchmark-iters 10 \ + --output-file ./results/grok2_ep8_tp4 + +# Run on each machine: bash run_grok2_bench.sh 0, bash run_grok2_bench.sh 1, ... +``` + +## Contributing + +To extend this benchmark: + +1. **Add new scenarios**: Edit `generate_router_logits()` in `synthetic_data.py` +2. **Add new metrics**: Modify `BenchmarkResult` in `benchmark_runner.py` +3. **Change output format**: Edit `output_formatter.py` + +## References + +- FusedEPMoE implementation: `sgl_jax/srt/layers/fused_moe.py` +- EPMoE implementation: `sgl_jax/srt/layers/moe.py` +- Main benchmark pattern: `sgl_jax/bench_one_batch.py` diff --git a/benchmark/fused_moe/bench_fused_vs_epmoe.py b/benchmark/fused_moe/bench_fused_vs_epmoe.py new file mode 100644 index 000000000..7d0110a26 --- /dev/null +++ b/benchmark/fused_moe/bench_fused_vs_epmoe.py @@ -0,0 +1,373 @@ +"""Benchmark script comparing FusedEPMoE vs EPMoE implementations. 
+ +This script performs layer-level benchmarking with synthetic weights and controlled +token distribution scenarios (random, balanced, imbalanced). + +Example usage: + # Quick test + python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --num-experts 8 --num-experts-per-tok 2 \ + --hidden-size 1024 --intermediate-size 4096 \ + --num-tokens 512 --scenarios random + + # Using HF model config + python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path Qwen/Qwen2.5-MoE-A2.7B \ + --ep-size 8 --num-tokens 1024 2048 4096 \ + --scenarios random balanced imbalanced \ + --profile --profile-dir ./profiles/qwen + + # 4 GPUs with expert parallelism (tp=4, ep=4, tp_actual=1) + python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path Qwen/Qwen2.5-MoE-A2.7B \ + --tp-size 4 --ep-size 4 + + # 8 GPUs with expert and tensor parallelism (tp=8, ep=4, tp_actual=2) + python benchmark/fused_moe/bench_fused_vs_epmoe.py \ + --model-path Qwen/Qwen2.5-MoE-A2.7B \ + --tp-size 8 --ep-size 4 +""" + +import argparse +import os +import sys + +import jax + +# Add python directory to path for imports +benchmark_dir = os.path.dirname(os.path.abspath(__file__)) # benchmark/fused_moe +benchmark_root = os.path.dirname(benchmark_dir) # benchmark +project_root = os.path.dirname(benchmark_root) # sgl-jax +python_dir = os.path.join(project_root, "python") # sgl-jax/python +sys.path.insert(0, python_dir) +sys.path.insert(0, project_root) # For benchmark imports + +from benchmark.fused_moe.benchmark_runner import MoEBenchmarkRunner # noqa: E402 +from benchmark.fused_moe.config_utils import MoEBenchmarkConfig # noqa: E402 +from benchmark.fused_moe.output_formatter import save_results # noqa: E402 +from benchmark.fused_moe.synthetic_data import create_synthetic_weights # noqa: E402 + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Benchmark FusedEPMoE vs EPMoE implementations", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # Model configuration (mutually exclusive) + config_group = parser.add_mutually_exclusive_group(required=True) + config_group.add_argument( + "--model-path", + type=str, + help="Path or name of HuggingFace model to load config from", + ) + config_group.add_argument( + "--manual-config", + action="store_true", + help="Use manual configuration (requires --num-experts, etc.)", + ) + + # Manual configuration options + parser.add_argument("--num-experts", type=int, help="Number of experts") + parser.add_argument("--num-experts-per-tok", type=int, help="Top-k value") + parser.add_argument("--hidden-size", type=int, help="Hidden dimension") + parser.add_argument("--intermediate-size", type=int, help="Intermediate dimension") + parser.add_argument( + "--activation", + type=str, + default="silu", + choices=["silu", "gelu", "swigluoai"], + help="Activation function", + ) + + # Distributed configuration + parser.add_argument( + "--ep-size", + type=int, + default=1, + help="Expert parallel size (default: 1)", + ) + parser.add_argument( + "--tp-size", + type=int, + default=1, + help="Total number of devices to use (default: 1)", + ) + parser.add_argument( + "--dist-init-addr", + type=str, + help="Distributed initialization address (e.g., 10.0.0.1:12345)", + ) + parser.add_argument( + "--nnodes", + type=int, + default=1, + help="Number of nodes (default: 1)", + ) + parser.add_argument( + "--node-rank", + type=int, + default=0, + help="Current node rank (default: 0)", + ) + + # Benchmark parameters + 
parser.add_argument( + "--num-tokens", + type=int, + nargs="+", + default=[512, 1024, 2048], + help="List of token counts to test (default: 512 1024 2048)", + ) + parser.add_argument( + "--scenarios", + type=str, + nargs="+", + default=["random", "balanced", "imbalanced"], + choices=["random", "balanced", "imbalanced"], + help="Scenarios to test (default: all)", + ) + parser.add_argument( + "--imbalance-factor", + type=float, + default=3.0, + help="Target imbalance factor for 'imbalanced' scenario (default: 3.0)", + ) + parser.add_argument( + "--warmup-iters", + type=int, + default=1, + help="Warmup iterations (default: 1, only need one for JAX JIT)", + ) + parser.add_argument( + "--benchmark-iters", + type=int, + default=10, + help="Benchmark iterations (default: 10)", + ) + + # Profiling + parser.add_argument( + "--profile", + action="store_true", + help="Enable JAX profiler", + ) + parser.add_argument( + "--profile-dir", + type=str, + default="./profiles", + help="Profile output directory (default: ./profiles)", + ) + + # Output + parser.add_argument( + "--output-format", + type=str, + default="both", + choices=["csv", "markdown", "both"], + help="Output format (default: both)", + ) + parser.add_argument( + "--output-file", + type=str, + default="./benchmark_results", + help="Output file base path (default: ./benchmark_results)", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + # Validate manual config + if args.manual_config: + required_manual = ["num_experts", "num_experts_per_tok", "hidden_size", "intermediate_size"] + missing = [arg for arg in required_manual if getattr(args, arg) is None] + if missing: + parser.error( + f"--manual-config requires: {', '.join('--' + m.replace('_', '-') for m in missing)}" + ) + + return args + + +def setup_distributed(args: argparse.Namespace) -> None: + """Initialize JAX distributed environment if needed.""" + if args.nnodes > 1: + if not args.dist_init_addr: + raise ValueError("--dist-init-addr is required for multi-node setup") + + print(f"Initializing distributed: nnodes={args.nnodes}, rank={args.node_rank}") + jax.distributed.initialize( + coordinator_address=args.dist_init_addr, + num_processes=args.nnodes, + process_id=args.node_rank, + ) + print(f"Distributed initialized. Process rank: {jax.process_index()}") + + +def create_mesh(tp_size: int) -> jax.sharding.Mesh: + """ + Create JAX mesh for MoE execution using create_device_mesh. + + This follows the same logic as scheduler.py. 
The MoE layers (FusedEPMoE and EPMoE) + will internally compute world_size from the mesh and calculate the actual tensor + parallel size as: tp_actual = world_size // ep_size + + Args: + tp_size: Total number of devices to use + + Returns: + JAX mesh with (data, tensor) axes + """ + from sgl_jax.srt.utils.mesh_utils import create_device_mesh + + mesh = create_device_mesh( + ici_parallelism=[-1, tp_size], + dcn_parallelism=[1, 1], + ) + + return mesh + + +def load_or_create_config(args: argparse.Namespace) -> MoEBenchmarkConfig: + """Load configuration from model path or create from manual args.""" + if args.model_path: + print(f"Loading config from model: {args.model_path}") + config = MoEBenchmarkConfig.from_model_path( + args.model_path, + ep_size=args.ep_size, + tp_size=args.tp_size, + ) + else: + print("Using manual configuration") + config = MoEBenchmarkConfig( + num_experts=args.num_experts, + num_experts_per_tok=args.num_experts_per_tok, + hidden_size=args.hidden_size, + intermediate_size=args.intermediate_size, + activation=args.activation, + ep_size=args.ep_size, + tp_size=args.tp_size, + ) + + # Validate config + config.validate() + + if args.verbose: + print("\n" + str(config)) + + return config + + +def main(): + """Main execution flow.""" + args = parse_args() + + print("=" * 80) + print("MoE Benchmark: FusedEPMoE vs EPMoE") + print("=" * 80) + + # Setup distributed + setup_distributed(args) + + # Create mesh + print(f"\nCreating JAX mesh: tp_size={args.tp_size}, ep_size={args.ep_size}") + mesh = create_mesh(args.tp_size) + print(f"Mesh created with {len(mesh.devices.flatten())} devices") + print(f"Mesh shape: {mesh.shape}") + + # Load configuration + config = load_or_create_config(args) + + # Generate synthetic weights + print("\nGenerating synthetic weights...") + fused_weights, epmoe_weights = create_synthetic_weights(config, mesh) + print(f"Weights generated: w1={fused_weights['w1'].shape}, w2={fused_weights['w2'].shape}") + + # Initialize benchmark runner + print("\nInitializing benchmark runner...") + runner = MoEBenchmarkRunner( + config=config, + mesh=mesh, + warmup_iters=args.warmup_iters, + benchmark_iters=args.benchmark_iters, + verbose=args.verbose, + ) + + runner.initialize_layers(fused_weights, epmoe_weights) + + # Run benchmarks + print("\n" + "=" * 80) + print("Running Benchmarks") + print("=" * 80) + + all_results = [] + + for scenario in args.scenarios: + for num_tokens in args.num_tokens: + print(f"\n{'=' * 80}") + print(f"Scenario: {scenario}, Tokens: {num_tokens}") + print(f"{'=' * 80}") + + if args.profile: + # Profile each implementation separately + profile_dir_fused = os.path.join( + args.profile_dir, f"{scenario}_tokens{num_tokens}_fused" + ) + profile_dir_epmoe = os.path.join( + args.profile_dir, f"{scenario}_tokens{num_tokens}_epmoe" + ) + + os.makedirs(profile_dir_fused, exist_ok=True) + os.makedirs(profile_dir_epmoe, exist_ok=True) + + print(f"Profiling enabled: {profile_dir_fused}, {profile_dir_epmoe}") + + # Run with profiling + jax.profiler.start_trace(profile_dir_fused) + fused_result, _ = runner.benchmark_scenario( + scenario, num_tokens, args.imbalance_factor + ) + jax.profiler.stop_trace() + + jax.profiler.start_trace(profile_dir_epmoe) + _, epmoe_result = runner.benchmark_scenario( + scenario, num_tokens, args.imbalance_factor + ) + jax.profiler.stop_trace() + + all_results.extend([fused_result, epmoe_result]) + + else: + # Run without profiling + fused_result, epmoe_result = runner.benchmark_scenario( + scenario, num_tokens, 
args.imbalance_factor + ) + all_results.extend([fused_result, epmoe_result]) + + # Print summary + speedup = epmoe_result.latency_mean / fused_result.latency_mean + print("\nResults:") + print(f" FusedEPMoE: {fused_result.latency_mean:.4f} ms (mean)") + print(f" EPMoE: {epmoe_result.latency_mean:.4f} ms (mean)") + print(f" Speedup: {speedup:.2f}x") + print(f" Imbalance: {fused_result.max_imbalance:.2f}x") + + # Save results + print("\n" + "=" * 80) + print("Saving Results") + print("=" * 80) + + save_results(all_results, args.output_file, args.output_format) + + print("\n" + "=" * 80) + print("Benchmark Complete!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/benchmark/fused_moe/benchmark_runner.py b/benchmark/fused_moe/benchmark_runner.py new file mode 100644 index 000000000..6d13b9200 --- /dev/null +++ b/benchmark/fused_moe/benchmark_runner.py @@ -0,0 +1,347 @@ +"""Core benchmark execution for MoE implementations.""" + +import time +from dataclasses import dataclass +from types import SimpleNamespace +from typing import List, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import Mesh + +from benchmark.fused_moe.config_utils import MoEBenchmarkConfig +from benchmark.fused_moe.synthetic_data import ( + compute_imbalance_metrics, + create_hidden_states, + generate_router_logits, +) + + +@dataclass +class BenchmarkResult: + """Results from a single benchmark run.""" + + implementation: str # "fused" or "epmoe" + scenario: str # "random", "balanced", "imbalanced" + num_tokens: int + ep_size: int + tp_size: int + num_experts: int + num_experts_per_tok: int + + # Latency metrics (in milliseconds) + latency_mean: float + latency_std: float + latency_p50: float + latency_p95: float + latency_p99: float + latency_min: float + latency_max: float + + # Load imbalance metrics + max_load: int + min_load: int + avg_load: float + max_imbalance: float + + # Throughput + throughput: float # tokens/sec + + +class MoEBenchmarkRunner: + """Orchestrates benchmark execution for both MoE implementations.""" + + def __init__( + self, + config: MoEBenchmarkConfig, + mesh: Mesh, + warmup_iters: int = 1, + benchmark_iters: int = 10, + verbose: bool = False, + ): + """ + Initialize benchmark runner. + + Args: + config: Benchmark configuration + mesh: JAX mesh with (expert, tensor) axes + warmup_iters: Number of warmup iterations (default: 1 for JAX JIT) + benchmark_iters: Number of benchmark iterations + verbose: Enable verbose logging + """ + self.config = config + self.mesh = mesh + self.warmup_iters = warmup_iters + self.benchmark_iters = benchmark_iters + self.verbose = verbose + + # Create dummy config for layer initialization + self.dummy_config = self._create_dummy_config() + + # Will be initialized later + self.fused_moe = None + self.epmoe_topk = None + self.epmoe = None + + def _create_dummy_config(self): + """Create a minimal config object for MoE layer initialization.""" + return SimpleNamespace( + hidden_size=self.config.hidden_size, + ep_size=self.config.ep_size, + ) + + def initialize_layers(self, fused_weights: dict, epmoe_weights: dict): + """ + Initialize both MoE implementations with synthetic weights. 
+ + Args: + fused_weights: Weights for FusedEPMoE (w1, w2) + epmoe_weights: Weights for EPMoE (wi_0, wi_1, wo) + """ + + from flax import nnx + + from sgl_jax.srt.layers.fused_moe import FusedEPMoE + from sgl_jax.srt.layers.moe import EPMoE, TopK + + dtype = jnp.bfloat16 if self.config.dtype == "bfloat16" else jnp.float32 + + # Initialize FusedEPMoE + if self.verbose: + print("Initializing FusedEPMoE...") + + self.fused_moe = nnx.eval_shape( + lambda: FusedEPMoE( + config=self.dummy_config, + num_experts=self.config.num_experts, + num_experts_per_tok=self.config.num_experts_per_tok, + ep_size=self.config.ep_size, + mesh=self.mesh, + intermediate_dim=self.config.intermediate_size, + weight_dtype=dtype, + dtype=dtype, + activation=self.config.activation, + renormalize_topk_logits=self.config.renormalize_topk_logits, + ) + ) + + # Overwrite weights with synthetic values + self.fused_moe.w1.value = fused_weights["w1"] + self.fused_moe.w2.value = fused_weights["w2"] + + # Initialize EPMoE components + if self.verbose: + print("Initializing EPMoE...") + + self.epmoe_topk = TopK( + topk=self.config.num_experts_per_tok, + renormalize=self.config.renormalize_topk_logits, + ) + + self.epmoe = nnx.eval_shape( + lambda: EPMoE( + config=self.dummy_config, + num_experts=self.config.num_experts, + num_experts_per_tok=self.config.num_experts_per_tok, + ep_size=self.config.ep_size, + mesh=self.mesh, + intermediate_dim=self.config.intermediate_size, + weight_dtype=dtype, + dtype=dtype, + activation=self.config.activation, + ) + ) + + # Overwrite weights + self.epmoe.wi_0.value = epmoe_weights["wi_0"] + self.epmoe.wi_1.value = epmoe_weights["wi_1"] + self.epmoe.wo.value = epmoe_weights["wo"] + + if self.verbose: + print("Layer initialization complete.") + + def run_fused_moe( + self, + hidden_states: jax.Array, + router_logits: jax.Array, + ) -> Tuple[jax.Array, List[float]]: + """ + Run FusedEPMoE forward pass and measure latency. + + Args: + hidden_states: Input tokens (num_tokens, hidden_size) + router_logits: Router logits (num_tokens, num_experts) + + Returns: + output: MoE output + latencies: List of latencies in milliseconds + """ + # Ensure inputs are on device + hidden_states = jax.device_put(hidden_states) + router_logits = jax.device_put(router_logits) + + # JIT compile + @jax.jit + def fused_forward(hidden_states, router_logits): + return self.fused_moe(hidden_states, router_logits) + + # Warmup (trigger JIT compilation) + if self.verbose: + print(f" Warmup: {self.warmup_iters} iteration(s)...") + for _ in range(self.warmup_iters): + output = fused_forward(hidden_states, router_logits) + jax.block_until_ready(output) + + # Benchmark + if self.verbose: + print(f" Benchmark: {self.benchmark_iters} iterations...") + latencies = [] + for _ in range(self.benchmark_iters): + start = time.perf_counter() + output = fused_forward(hidden_states, router_logits) + jax.block_until_ready(output) + latencies.append((time.perf_counter() - start) * 1000) # Convert to ms + + return output, latencies + + def run_epmoe( + self, + hidden_states: jax.Array, + router_logits: jax.Array, + ) -> Tuple[jax.Array, List[float], jax.Array]: + """ + Run EPMoE forward pass and measure latency. 
+ + Args: + hidden_states: Input tokens (num_tokens, hidden_size) + router_logits: Router logits (num_tokens, num_experts) + + Returns: + output: MoE output + latencies: List of latencies in milliseconds + topk_ids: Expert assignments for imbalance calculation + """ + hidden_states = jax.device_put(hidden_states) + router_logits = jax.device_put(router_logits) + + # JIT compile (TopK + EPMoE together) + @jax.jit + def epmoe_forward(hidden_states, router_logits): + topk_weights, topk_ids = self.epmoe_topk(router_logits) + output = self.epmoe(hidden_states, topk_weights, topk_ids) + return output, topk_ids + + # Warmup + if self.verbose: + print(f" Warmup: {self.warmup_iters} iteration(s)...") + for _ in range(self.warmup_iters): + output, topk_ids = epmoe_forward(hidden_states, router_logits) + jax.block_until_ready(output) + + # Benchmark + if self.verbose: + print(f" Benchmark: {self.benchmark_iters} iterations...") + latencies = [] + for _ in range(self.benchmark_iters): + start = time.perf_counter() + output, topk_ids = epmoe_forward(hidden_states, router_logits) + jax.block_until_ready(output) + latencies.append((time.perf_counter() - start) * 1000) + + return output, latencies, topk_ids + + def benchmark_scenario( + self, + scenario: str, + num_tokens: int, + imbalance_factor: float = 3.0, + ) -> Tuple[BenchmarkResult, BenchmarkResult]: + """ + Run benchmark for a single scenario. + + Args: + scenario: "random", "balanced", or "imbalanced" + num_tokens: Number of tokens to test + imbalance_factor: Target imbalance for "imbalanced" scenario + + Returns: + fused_result: Results for FusedEPMoE + epmoe_result: Results for EPMoE + """ + if self.verbose: + print(f"\nBenchmarking scenario={scenario}, num_tokens={num_tokens}") + + # Generate data + hidden_states = create_hidden_states( + num_tokens, + self.config.hidden_size, + dtype=jnp.bfloat16 if self.config.dtype == "bfloat16" else jnp.float32, + ) + + router_logits = generate_router_logits( + num_tokens, + self.config.num_experts, + scenario, + num_experts_per_tok=self.config.num_experts_per_tok, + imbalance_factor=imbalance_factor, + ) + + # Run FusedEPMoE + if self.verbose: + print("Running FusedEPMoE...") + fused_output, fused_latencies = self.run_fused_moe(hidden_states, router_logits) + + # Run EPMoE + if self.verbose: + print("Running EPMoE...") + epmoe_output, epmoe_latencies, topk_ids = self.run_epmoe(hidden_states, router_logits) + + # Compute imbalance metrics (same for both since they use same router logits) + imbalance = compute_imbalance_metrics(topk_ids, self.config.num_experts) + + if self.verbose: + print(f" Max imbalance: {imbalance['max_imbalance']:.2f}x") + + # Create results + fused_result = self._create_result( + "fused", scenario, num_tokens, fused_latencies, imbalance + ) + epmoe_result = self._create_result( + "epmoe", scenario, num_tokens, epmoe_latencies, imbalance + ) + + return fused_result, epmoe_result + + def _create_result( + self, + implementation: str, + scenario: str, + num_tokens: int, + latencies: List[float], + imbalance: dict, + ) -> BenchmarkResult: + """Create BenchmarkResult from latency measurements.""" + latencies_array = np.array(latencies) + + return BenchmarkResult( + implementation=implementation, + scenario=scenario, + num_tokens=num_tokens, + ep_size=self.config.ep_size, + tp_size=self.config.tp_size, + num_experts=self.config.num_experts, + num_experts_per_tok=self.config.num_experts_per_tok, + latency_mean=float(np.mean(latencies_array)), + latency_std=float(np.std(latencies_array)), + 
latency_p50=float(np.percentile(latencies_array, 50)), + latency_p95=float(np.percentile(latencies_array, 95)), + latency_p99=float(np.percentile(latencies_array, 99)), + latency_min=float(np.min(latencies_array)), + latency_max=float(np.max(latencies_array)), + max_load=imbalance["max_load"], + min_load=imbalance["min_load"], + avg_load=imbalance["avg_load"], + max_imbalance=imbalance["max_imbalance"], + throughput=num_tokens / (np.mean(latencies_array) / 1000), # tokens/sec + ) diff --git a/benchmark/fused_moe/config_utils.py b/benchmark/fused_moe/config_utils.py new file mode 100644 index 000000000..4dc0ba892 --- /dev/null +++ b/benchmark/fused_moe/config_utils.py @@ -0,0 +1,105 @@ +"""Configuration utilities for MoE benchmark.""" + +from dataclasses import dataclass + + +@dataclass +class MoEBenchmarkConfig: + """Configuration for MoE benchmark.""" + + num_experts: int + num_experts_per_tok: int + hidden_size: int + intermediate_size: int + activation: str = "silu" + renormalize_topk_logits: bool = True + dtype: str = "bfloat16" + weight_dtype: str = "bfloat16" + + # Distributed config + ep_size: int = 1 + tp_size: int = 1 + + @classmethod + def from_model_path( + cls, model_path: str, ep_size: int = 1, tp_size: int = 1 + ) -> "MoEBenchmarkConfig": + """ + Load configuration from model path using AutoConfig.from_pretrained(). + + Downloads from HuggingFace if needed, or loads from local models directory. + + Args: + model_path: Path or name of HuggingFace model + ep_size: Expert parallel size + tp_size: Total number of devices to use + + Returns: + MoEBenchmarkConfig instance + """ + from transformers import AutoConfig + + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + # Extract MoE-specific parameters with fallbacks + num_experts = getattr( + hf_config, + "num_experts", + getattr(hf_config, "num_local_experts", 8), + ) + + num_experts_per_tok = getattr(hf_config, "num_experts_per_tok", 2) + + intermediate_size = getattr( + hf_config, + "moe_intermediate_size", + getattr(hf_config, "intermediate_size", 2048), + ) + + activation = getattr(hf_config, "hidden_act", "silu") + renormalize = getattr(hf_config, "norm_topk_prob", True) + + return cls( + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + hidden_size=hf_config.hidden_size, + intermediate_size=intermediate_size, + activation=activation, + renormalize_topk_logits=renormalize, + ep_size=ep_size, + tp_size=tp_size, + ) + + def validate(self) -> None: + """Validate configuration parameters.""" + if self.num_experts % self.ep_size != 0: + raise ValueError( + f"num_experts ({self.num_experts}) must be divisible by ep_size ({self.ep_size})" + ) + + if self.num_experts_per_tok > self.num_experts: + raise ValueError( + f"num_experts_per_tok ({self.num_experts_per_tok}) must be <= num_experts ({self.num_experts})" + ) + + # Check activation is supported + supported_activations = ["silu", "gelu", "swigluoai"] + if self.activation not in supported_activations: + raise ValueError( + f"Unsupported activation '{self.activation}'. 
Supported: {supported_activations}" + ) + + def __str__(self) -> str: + """String representation for logging.""" + return ( + f"MoEBenchmarkConfig(\n" + f" num_experts={self.num_experts},\n" + f" num_experts_per_tok={self.num_experts_per_tok},\n" + f" hidden_size={self.hidden_size},\n" + f" intermediate_size={self.intermediate_size},\n" + f" activation={self.activation},\n" + f" renormalize_topk_logits={self.renormalize_topk_logits},\n" + f" ep_size={self.ep_size},\n" + f" tp_size={self.tp_size}\n" + f")" + ) diff --git a/benchmark/fused_moe/output_formatter.py b/benchmark/fused_moe/output_formatter.py new file mode 100644 index 000000000..2980591fe --- /dev/null +++ b/benchmark/fused_moe/output_formatter.py @@ -0,0 +1,185 @@ +"""Output formatting for benchmark results (CSV and Markdown).""" + +import csv +import io +from typing import List + +from benchmark.fused_moe.benchmark_runner import BenchmarkResult + + +def format_as_csv(results: List[BenchmarkResult]) -> str: + """ + Format benchmark results as CSV. + + CSV Schema: + implementation,scenario,num_tokens,ep_size,tp_size,num_experts, + num_experts_per_tok,latency_mean_ms,latency_std_ms,latency_p50_ms, + latency_p95_ms,latency_p99_ms,latency_min_ms,latency_max_ms, + max_load,min_load,avg_load,max_imbalance,throughput_tok_per_sec + + Args: + results: List of benchmark results + + Returns: + CSV formatted string + """ + header = [ + "implementation", + "scenario", + "num_tokens", + "ep_size", + "tp_size", + "num_experts", + "num_experts_per_tok", + "latency_mean_ms", + "latency_std_ms", + "latency_p50_ms", + "latency_p95_ms", + "latency_p99_ms", + "latency_min_ms", + "latency_max_ms", + "max_load", + "min_load", + "avg_load", + "max_imbalance", + "throughput_tok_per_sec", + ] + + rows = [] + for r in results: + rows.append( + [ + r.implementation, + r.scenario, + r.num_tokens, + r.ep_size, + r.tp_size, + r.num_experts, + r.num_experts_per_tok, + f"{r.latency_mean:.4f}", + f"{r.latency_std:.4f}", + f"{r.latency_p50:.4f}", + f"{r.latency_p95:.4f}", + f"{r.latency_p99:.4f}", + f"{r.latency_min:.4f}", + f"{r.latency_max:.4f}", + r.max_load, + r.min_load, + f"{r.avg_load:.2f}", + f"{r.max_imbalance:.2f}", + f"{r.throughput:.2f}", + ] + ) + + output = io.StringIO() + writer = csv.writer(output) + writer.writerow(header) + writer.writerows(rows) + return output.getvalue() + + +def format_as_markdown(results: List[BenchmarkResult]) -> str: + """ + Format benchmark results as Markdown table. + + Groups by scenario and num_tokens, shows side-by-side comparison. 
+ + Args: + results: List of benchmark results + + Returns: + Markdown formatted string + """ + if not results: + return "# MoE Benchmark Results\n\nNo results to display.\n" + + # Group results by (scenario, num_tokens) + grouped = {} + for r in results: + key = (r.scenario, r.num_tokens) + if key not in grouped: + grouped[key] = {} + grouped[key][r.implementation] = r + + lines = [] + lines.append("# MoE Benchmark Results\n") + + # Add configuration info from first result + first_result = results[0] + lines.append( + f"**Configuration:** {first_result.num_experts} experts, " + f"top-{first_result.num_experts_per_tok}, " + f"EP={first_result.ep_size}, TP={first_result.tp_size}\n" + ) + + # Create tables for each scenario + for (scenario, num_tokens), impls in sorted(grouped.items()): + lines.append(f"\n## Scenario: {scenario}, Tokens: {num_tokens}\n") + + # Table header + lines.append("| Metric | Fused MoE | EP MoE | Speedup |") + lines.append("|--------|-----------|--------|---------|") + + fused = impls.get("fused") + epmoe = impls.get("epmoe") + + if fused and epmoe: + speedup = epmoe.latency_mean / fused.latency_mean + + lines.append( + f"| Mean Latency (ms) | {fused.latency_mean:.4f} | " + f"{epmoe.latency_mean:.4f} | {speedup:.2f}x |" + ) + lines.append( + f"| P95 Latency (ms) | {fused.latency_p95:.4f} | " f"{epmoe.latency_p95:.4f} | - |" + ) + lines.append( + f"| P99 Latency (ms) | {fused.latency_p99:.4f} | " f"{epmoe.latency_p99:.4f} | - |" + ) + lines.append( + f"| Throughput (tok/s) | {fused.throughput:.2f} | " f"{epmoe.throughput:.2f} | - |" + ) + lines.append( + f"| Max Imbalance | {fused.max_imbalance:.2f}x | " + f"{epmoe.max_imbalance:.2f}x | - |" + ) + elif fused: + lines.append(f"| Mean Latency (ms) | {fused.latency_mean:.4f} | N/A | - |") + lines.append(f"| P95 Latency (ms) | {fused.latency_p95:.4f} | N/A | - |") + lines.append(f"| Throughput (tok/s) | {fused.throughput:.2f} | N/A | - |") + lines.append(f"| Max Imbalance | {fused.max_imbalance:.2f}x | N/A | - |") + elif epmoe: + lines.append(f"| Mean Latency (ms) | N/A | {epmoe.latency_mean:.4f} | - |") + lines.append(f"| P95 Latency (ms) | N/A | {epmoe.latency_p95:.4f} | - |") + lines.append(f"| Throughput (tok/s) | N/A | {epmoe.throughput:.2f} | - |") + lines.append(f"| Max Imbalance | N/A | {epmoe.max_imbalance:.2f}x | - |") + + return "\n".join(lines) + + +def save_results( + results: List[BenchmarkResult], + output_file: str, + output_format: str = "both", +): + """ + Save benchmark results to files. 
+ + Args: + results: List of benchmark results + output_file: Base output file path (without extension) + output_format: "csv", "markdown", or "both" + """ + if output_format in ("csv", "both"): + csv_content = format_as_csv(results) + csv_path = f"{output_file}.csv" + with open(csv_path, "w") as f: + f.write(csv_content) + print(f"CSV results saved to {csv_path}") + + if output_format in ("markdown", "both"): + md_content = format_as_markdown(results) + md_path = f"{output_file}.md" + with open(md_path, "w") as f: + f.write(md_content) + print(f"Markdown results saved to {md_path}") diff --git a/benchmark/fused_moe/synthetic_data.py b/benchmark/fused_moe/synthetic_data.py new file mode 100644 index 000000000..475a1bafb --- /dev/null +++ b/benchmark/fused_moe/synthetic_data.py @@ -0,0 +1,232 @@ +"""Synthetic data generation for MoE benchmark.""" + +from typing import Dict, Tuple + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh + +from benchmark.fused_moe.config_utils import MoEBenchmarkConfig + + +def create_synthetic_weights( + config: MoEBenchmarkConfig, + mesh: Mesh, + seed: int = 42, +) -> Tuple[Dict[str, jax.Array], Dict[str, jax.Array]]: + """ + Create synthetic weights for both FusedEPMoE and EPMoE. + + Ensures mathematical equivalence between the two implementations. + + Args: + config: Benchmark configuration + mesh: JAX mesh for sharding + seed: Random seed for reproducibility + + Returns: + fused_weights: Dictionary with keys "w1", "w2" for FusedEPMoE + epmoe_weights: Dictionary with keys "wi_0", "wi_1", "wo" for EPMoE + """ + key = jax.random.PRNGKey(seed) + key1, key2, key3 = jax.random.split(key, 3) + + dtype = jnp.bfloat16 if config.weight_dtype == "bfloat16" else jnp.float32 + + # Generate base weights in EPMoE format + # wi_0: gate projection (num_experts, hidden_size, intermediate_size) + wi_0 = ( + jax.random.normal( + key1, + (config.num_experts, config.hidden_size, config.intermediate_size), + dtype=dtype, + ) + * 0.02 + ) + + # wi_1: up projection (num_experts, hidden_size, intermediate_size) + wi_1 = ( + jax.random.normal( + key2, + (config.num_experts, config.hidden_size, config.intermediate_size), + dtype=dtype, + ) + * 0.02 + ) + + # wo: down projection (num_experts, intermediate_size, hidden_size) + wo = ( + jax.random.normal( + key3, + (config.num_experts, config.intermediate_size, config.hidden_size), + dtype=dtype, + ) + * 0.02 + ) + + # Create fused format for FusedEPMoE + # IMPORTANT: FusedEPMoE expects transposed weights! + # w1[:, 0, :, :] = wi_0.transpose(0, 2, 1) # gate + # w1[:, 1, :, :] = wi_1.transpose(0, 2, 1) # up + w1_gate = jnp.transpose(wi_0, (0, 2, 1)) # (num_experts, intermediate_size, hidden_size) + w1_up = jnp.transpose(wi_1, (0, 2, 1)) # (num_experts, intermediate_size, hidden_size) + w1 = jnp.stack([w1_gate, w1_up], axis=1) # (num_experts, 2, intermediate_size, hidden_size) + + w2 = wo # Same format for both + + fused_weights = { + "w1": w1, + "w2": w2, + } + + epmoe_weights = { + "wi_0": wi_0, + "wi_1": wi_1, + "wo": wo, + } + + return fused_weights, epmoe_weights + + +def generate_router_logits( + num_tokens: int, + num_experts: int, + scenario: str, + num_experts_per_tok: int = 2, + imbalance_factor: float = 3.0, + seed: int = 42, +) -> jax.Array: + """ + Generate router logits with different distribution patterns. 
+ + Args: + num_tokens: Number of tokens + num_experts: Total number of experts + scenario: One of "random", "balanced", "imbalanced" + num_experts_per_tok: Top-k value (for balanced scenario) + imbalance_factor: Target max_load / avg_load for imbalanced scenario + seed: Random seed + + Returns: + router_logits: (num_tokens, num_experts) array of logits + + Scenarios: + - random: Uniform random N(0, 1) logits, natural imbalance ~1.2-1.5x + - balanced: Engineered to achieve ~1.0x imbalance (perfect balance) + - imbalanced: Skewed to achieve target imbalance_factor (default 3.0x) + """ + key = jax.random.PRNGKey(seed) + + if scenario == "random": + # Uniform random logits + router_logits = jax.random.normal(key, (num_tokens, num_experts), dtype=jnp.float32) + + elif scenario == "balanced": + # Round-robin assignment to ensure equal distribution + router_logits = jnp.ones((num_tokens, num_experts), dtype=jnp.float32) * -10.0 + + # Assign each token to experts in round-robin fashion + for token_idx in range(num_tokens): + # Calculate which experts this token should prefer + start_expert = (token_idx * num_experts_per_tok) % num_experts + for k in range(num_experts_per_tok): + expert_idx = (start_expert + k) % num_experts + router_logits = router_logits.at[token_idx, expert_idx].set(10.0) + + # Add small random noise for diversity (but keep assignment clear) + noise = jax.random.normal(key, router_logits.shape, dtype=jnp.float32) * 0.1 + router_logits = router_logits + noise + + elif scenario == "imbalanced": + # Create exponential distribution favoring first few experts + # Adjust temperature to achieve target imbalance_factor + + # Start with exponential decay + temperature = num_experts / (imbalance_factor * 2) # Heuristic + expert_base_logits = jnp.arange(num_experts, dtype=jnp.float32) + expert_base_logits = 10.0 * jnp.exp(-expert_base_logits / temperature) + + # Broadcast to all tokens with random variation + router_logits = jnp.tile(expert_base_logits, (num_tokens, 1)) + + # Add random noise to create variation + noise = jax.random.normal(key, router_logits.shape, dtype=jnp.float32) * 2.0 + router_logits = router_logits + noise + + else: + raise ValueError( + f"Unknown scenario '{scenario}'. Must be one of: random, balanced, imbalanced" + ) + + return router_logits + + +def compute_imbalance_metrics( + topk_ids: jax.Array, + num_experts: int, +) -> Dict[str, float]: + """ + Compute load imbalance metrics from expert assignments. + + Args: + topk_ids: (num_tokens, num_experts_per_tok) expert indices + num_experts: Total number of experts + + Returns: + Dictionary containing: + max_load: Maximum tokens assigned to any expert + min_load: Minimum tokens assigned to any expert + avg_load: Average tokens per expert + std_load: Standard deviation of load + max_imbalance: max_load / avg_load + min_imbalance: min_load / avg_load + load_distribution: Per-expert load counts (list) + + Example: + If avg_load = 100 and max_load = 300, then max_imbalance = 3.0 + This means the busiest expert received 3x more tokens than average. 
+ """ + # Flatten topk_ids and count occurrences per expert + flat_ids = topk_ids.flatten() + expert_counts = jnp.bincount(flat_ids, length=num_experts) + + max_load = int(jnp.max(expert_counts)) + min_load = int(jnp.min(expert_counts)) + avg_load = float(jnp.mean(expert_counts)) + std_load = float(jnp.std(expert_counts)) + + # Compute imbalance ratios + max_imbalance = float(max_load / avg_load) if avg_load > 0 else float("inf") + min_imbalance = float(min_load / avg_load) if avg_load > 0 and min_load > 0 else 0.0 + + return { + "max_load": max_load, + "min_load": min_load, + "avg_load": avg_load, + "std_load": std_load, + "max_imbalance": max_imbalance, + "min_imbalance": min_imbalance, + "load_distribution": expert_counts.tolist(), + } + + +def create_hidden_states( + num_tokens: int, + hidden_size: int, + dtype: jnp.dtype = jnp.bfloat16, + seed: int = 42, +) -> jax.Array: + """ + Create synthetic input hidden states. + + Args: + num_tokens: Number of tokens + hidden_size: Hidden dimension + dtype: Data type + seed: Random seed + + Returns: + hidden_states: (num_tokens, hidden_size) array + """ + key = jax.random.PRNGKey(seed) + return jax.random.normal(key, (num_tokens, hidden_size), dtype=dtype) * 0.02 diff --git a/python/sgl_jax/bench_offline_throughput.py b/python/sgl_jax/bench_offline_throughput.py index 1fc811c69..6c8fd7c62 100644 --- a/python/sgl_jax/bench_offline_throughput.py +++ b/python/sgl_jax/bench_offline_throughput.py @@ -4,10 +4,10 @@ # Usage ## Sharegpt dataset with default args -python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10 +python -m sgl_jax.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --num-prompts 10 ## Random dataset with default args -python -m sglang.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 +python -m sgl_jax.bench_offline_throughput --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dataset-name random --random-input 1024 --random-output 1024 """ import argparse diff --git a/python/sgl_jax/bench_one_batch.py b/python/sgl_jax/bench_one_batch.py index 05fa71b4b..6d0ae43b0 100644 --- a/python/sgl_jax/bench_one_batch.py +++ b/python/sgl_jax/bench_one_batch.py @@ -6,13 +6,15 @@ # Usage (latency test) ## with dummy weights: -python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy +python -m sgl_jax.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy ## sweep through multiple data points and store (append) the results in a jsonl file: -python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run +python -m sgl_jax.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch-size 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run ## run with profiling: -python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile +python -m sgl_jax.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch-size 1 12 14 --input-len 256 512 --profile +## run with custom prompts from file: +python -m sgl_jax.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --prompt-filename prompts.txt --batch-size 4 # Usage (correctness test): -python -m sglang.bench_one_batch --model-path 
TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct +python -m sgl_jax.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correctness-test ## Reference output (of the correctness test above, can be tpu dependent): input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]] @@ -43,6 +45,7 @@ """ import argparse +import copy import dataclasses import itertools import json @@ -74,6 +77,7 @@ class BenchArgs: batch_size: tuple[int] = (1,) input_len: tuple[int] = (1024,) output_len: tuple[int] = (16,) + prompt_filename: str = "" result_filename: str = "result.jsonl" correctness_test: bool = False # This is only used for correctness test @@ -88,6 +92,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument("--batch-size", type=int, nargs="+", default=BenchArgs.batch_size) parser.add_argument("--input-len", type=int, nargs="+", default=BenchArgs.input_len) parser.add_argument("--output-len", type=int, nargs="+", default=BenchArgs.output_len) + parser.add_argument("--prompt-filename", type=str, default=BenchArgs.prompt_filename) parser.add_argument("--result-filename", type=str, default=BenchArgs.result_filename) parser.add_argument("--correctness-test", action="store_true") parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len) @@ -113,6 +118,17 @@ def from_cli_args(cls, args: argparse.Namespace): return cls(**{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}) +def _read_prompts_from_file(prompt_file, rank_print): + """Read custom prompts from the file specified by `--prompt-filename`.""" + if not prompt_file: + return [] + if not os.path.exists(prompt_file): + rank_print(f"Custom prompt file {prompt_file} not found. 
Using default inputs...") + return [] + with open(prompt_file) as pf: + return pf.readlines() + + def load_model(server_args, port_args, tp_rank): rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None @@ -148,12 +164,16 @@ def load_model(server_args, port_args, tp_rank): return model_runner, tokenizer -def prepare_inputs_for_correctness_test(bench_args, tokenizer): - prompts = [ - "The capital of France is", - "The capital of the United Kindom is", - "Today is a sunny day and I like", - ] +def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts): + prompts = ( + custom_prompts + if custom_prompts + else [ + "The capital of France is", + "The capital of the United Kindom is", + "Today is a sunny day and I like", + ] + ) input_ids = [tokenizer.encode(p) for p in prompts] sampling_params = SamplingParams( temperature=0, @@ -190,8 +210,12 @@ def prepare_extend_inputs_for_correctness_test(bench_args, input_ids, reqs, mode return reqs -def prepare_synthetic_inputs_for_latency_test(batch_size, input_len): - input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32) +def prepare_synthetic_inputs_for_latency_test(batch_size, input_len, custom_inputs=None): + input_ids = ( + custom_inputs + if custom_inputs + else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32) + ) sampling_params = SamplingParams( temperature=0, max_new_tokens=(BenchArgs.output_len[0] if isinstance(BenchArgs.output_len, tuple) else 16), @@ -309,7 +333,8 @@ def correctness_test( model_runner, tokenizer = load_model(server_args, port_args, tp_rank) # Prepare inputs - input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer) + custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print) + input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts) rank_print(f"\n{input_ids=}\n") if bench_args.cut_len > 0: @@ -476,12 +501,34 @@ def latency_test( rank_print("Benchmark ...") + custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print) + custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs] + custom_input_len = len(custom_inputs) + # Run the sweep result_list = [] for bs, il, ol in itertools.product( bench_args.batch_size, bench_args.input_len, bench_args.output_len ): - reqs = prepare_synthetic_inputs_for_latency_test(bs, il) + bs_aligned_inputs = [] + if custom_inputs: + if custom_input_len == bs: + bs_aligned_inputs = custom_inputs + elif custom_input_len > bs: + rank_print( + f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). " + f"Using the first {bs} prompts." + ) + bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs]) + else: + rank_print( + f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). " + f"Pad to the desired batch_size with the last prompt." + ) + bs_aligned_inputs = copy.deepcopy(custom_inputs) + bs_aligned_inputs.extend([bs_aligned_inputs[-1]] * (bs - custom_input_len)) + + reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs) ret = latency_test_run_once( bench_args.run_name, model_runner, diff --git a/python/sgl_jax/bench_one_batch_server.py b/python/sgl_jax/bench_one_batch_server.py index 7dcacfd38..2b7f791eb 100644 --- a/python/sgl_jax/bench_one_batch_server.py +++ b/python/sgl_jax/bench_one_batch_server.py @@ -5,10 +5,13 @@ It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths). 
Usage: -python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8 +python3 -m sgl_jax.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8 -python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 -python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage +python3 -m sgl_jax.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 +python3 -m sgl_jax.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage + +# Use OpenAI-compatible API: +python3 -m sgl_jax.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --api-type openai """ import argparse @@ -45,6 +48,7 @@ class BenchArgs: show_report: bool = False profile: bool = False profile_by_stage: bool = False + api_type: str = "native" # "native" or "openai" @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -70,6 +74,13 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument("--show-report", action="store_true") parser.add_argument("--profile", action="store_true") parser.add_argument("--profile-by-stage", action="store_true") + parser.add_argument( + "--api-type", + type=str, + default=BenchArgs.api_type, + choices=["native", "openai"], + help="API type to use: 'native' for /generate or 'openai' for /v1/completions", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -122,8 +133,12 @@ def run_one_case( tokenizer, profile: bool = False, profile_by_stage: bool = False, + api_type: str = "native", ): requests.post(url + "/flush_cache") + + # Determine whether to use text or input_ids based on API type + return_text = api_type == "openai" input_requests = sample_random_requests( input_len=input_len, output_len=output_len, @@ -132,7 +147,7 @@ def run_one_case( tokenizer=tokenizer, dataset_path="", random_sample=True, - return_text=False, + return_text=return_text, ) use_structured_outputs = False @@ -153,40 +168,96 @@ def run_one_case( profile_link: str = run_profile(url, 3, ["CPU", "GPU"], None, None, profile_by_stage) tic = time.perf_counter() - response = requests.post( - url + "/generate", - json={ - "input_ids": [req.prompt for req in input_requests], - "sampling_params": { + + if api_type == "openai": + # Use OpenAI API - send requests as a batch with n parameter + # Convert prompts to a single request with n > 1 if batch_size > 1 + if batch_size == 1: + request_data = { + "model": "default", + "prompt": input_requests[0].prompt, + "max_tokens": output_len, + "temperature": temperature, + "logprobs": 1 if return_logprob else None, + "stream": True, + } + else: + # For batch, we'll send the first prompt with n=batch_size + # Note: This assumes all prompts should be the same, which may not be ideal + # For different prompts, we'd need to send separate requests + request_data = { + "model": "default", + "prompt": [req.prompt for req in input_requests], + "max_tokens": output_len, "temperature": temperature, - "max_new_tokens": output_len, - "ignore_eos": True, - "json_schema": json_schema, - "stream_interval": stream_interval, + 
"logprobs": 1 if return_logprob else None, + "stream": True, + } + + response = requests.post( + url + "/v1/completions", + json=request_data, + stream=True, + ) + + # Parse OpenAI streaming response + ttft = 0.0 + first_token_received = False + total_chunks = 0 + for chunk in response.iter_lines(decode_unicode=False): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]": + break + data = json.loads(chunk[5:].strip("\n")) + if "error" in data: + raise RuntimeError(f"Request has failed. {data}.") + + # Track first token time (across all choices in batch) + if data.get("choices") and len(data["choices"]) > 0: + # Check if any choice has generated text + for choice in data["choices"]: + if choice.get("text") and not first_token_received: + ttft = time.perf_counter() - tic + first_token_received = True + break + total_chunks += 1 + else: + # Use native API + response = requests.post( + url + "/generate", + json={ + "input_ids": [req.prompt for req in input_requests], + "sampling_params": { + "temperature": temperature, + "max_new_tokens": output_len, + "ignore_eos": True, + "json_schema": json_schema, + "stream_interval": stream_interval, + }, + "return_logprob": return_logprob, + "stream": True, }, - "return_logprob": return_logprob, - "stream": True, - }, - stream=True, - ) + stream=True, + ) - # The TTFT of the last request in the batch - ttft = 0.0 - for chunk in response.iter_lines(decode_unicode=False): - chunk = chunk.decode("utf-8") - if chunk and chunk.startswith("data:"): - if chunk == "data: [DONE]": - break - data = json.loads(chunk[5:].strip("\n")) - if "error" in data: - raise RuntimeError(f"Request has failed. {data}.") - - assert ( - data["meta_info"]["finish_reason"] is None - or data["meta_info"]["finish_reason"]["type"] == "length" - ) - if data["meta_info"]["completion_tokens"] == 1: - ttft = time.perf_counter() - tic + # The TTFT of the last request in the batch + ttft = 0.0 + for chunk in response.iter_lines(decode_unicode=False): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]": + break + data = json.loads(chunk[5:].strip("\n")) + if "error" in data: + raise RuntimeError(f"Request has failed. 
{data}.") + + assert ( + data["meta_info"]["finish_reason"] is None + or data["meta_info"]["finish_reason"]["type"] == "length" + ) + if data["meta_info"]["completion_tokens"] == 1: + ttft = time.perf_counter() - tic latency = time.perf_counter() - tic input_throughput = batch_size * input_len / ttft @@ -262,6 +333,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): run_name="", result_filename="", tokenizer=tokenizer, + api_type=bench_args.api_type, ) print("=" * 8 + " Warmup End " + "=" * 8 + "\n") @@ -285,6 +357,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): run_name=bench_args.run_name, result_filename=bench_args.result_filename, tokenizer=tokenizer, + api_type=bench_args.api_type, ) ) @@ -309,6 +382,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): tokenizer=tokenizer, profile=bench_args.profile, profile_by_stage=bench_args.profile_by_stage, + api_type=bench_args.api_type, )[-1], ) ) diff --git a/python/sgl_jax/bench_serving.py b/python/sgl_jax/bench_serving.py index c96109d4a..8346fedc1 100644 --- a/python/sgl_jax/bench_serving.py +++ b/python/sgl_jax/bench_serving.py @@ -1656,7 +1656,7 @@ def __call__(self, parser, namespace, values, option_string=None): "--backend", type=str, choices=list(ASYNC_REQUEST_FUNCS.keys()), - default="sglang", + default="sgl-jax", help="Must specify a backend, depending on the LLM Inference Engine.", ) parser.add_argument( diff --git a/python/sgl_jax/profiler.py b/python/sgl_jax/profiler.py index a599684d1..31465b51d 100644 --- a/python/sgl_jax/profiler.py +++ b/python/sgl_jax/profiler.py @@ -2,7 +2,7 @@ Run live profiling. Usage: -python3 -m sglang.profiler +python3 -m sgl_jax.profiler """ import argparse @@ -14,7 +14,7 @@ import requests -PARENT_FOLDER = "/tmp/sglang-profile" +PARENT_FOLDER = "/tmp/sgl-jax-profile" def _run_profile( diff --git a/python/sgl_jax/srt/configs/model_config.py b/python/sgl_jax/srt/configs/model_config.py index fc4972ad2..534691a40 100644 --- a/python/sgl_jax/srt/configs/model_config.py +++ b/python/sgl_jax/srt/configs/model_config.py @@ -29,6 +29,14 @@ class ModelImpl(str, Enum): TRANSFORMERS = "transformers" +class MoEBackend(str, Enum): + """Backend for Mixture of Experts computation.""" + + EPMOE = "epmoe" # Native Expert Parallel MoE (default) + FUSED = "fused" # Fused Kernel (TPU-optimized) + AUTO = "auto" # Automatically select based on ep_size + + class ModelConfig: def __init__( self, @@ -44,6 +52,7 @@ def __init__( model_impl: str | ModelImpl = ModelImpl.AUTO, quantization: str | None = None, model_layer_nums: int | None = None, + moe_backend: str | MoEBackend = MoEBackend.AUTO, ) -> None: self.model_path = model_path @@ -53,6 +62,15 @@ def __init__( # if ep_size > 1, use ep moe, else use fused moe # TODO: support ep moe with ETP self.ep_size = 1 + + # Process MoE backend selection + self.moe_backend = MoEBackend(moe_backend) if isinstance(moe_backend, str) else moe_backend + + # Auto-select backend based on ep_size + if self.moe_backend == MoEBackend.AUTO: + # If ep_size > 1, use EPMoE (expert parallelism across devices) + # Otherwise use Fused kernel (single-device TPU optimization) + self.moe_backend = MoEBackend.EPMOE if self.ep_size > 1 else MoEBackend.FUSED # Parse args self.maybe_pull_model_tokenizer_from_remote() self.model_override_args = json.loads(model_override_args) @@ -176,6 +194,7 @@ def from_server_args( quantization=server_args.quantization, model_impl=server_args.model_impl, model_layer_nums=server_args.model_layer_nums, + 
moe_backend=server_args.moe_backend, **kwargs, ) diff --git a/python/sgl_jax/srt/kernels/fused_moe/__init__.py b/python/sgl_jax/srt/kernels/fused_moe/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/sgl_jax/srt/kernels/fused_moe/v1/__init__.py b/python/sgl_jax/srt/kernels/fused_moe/v1/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/sgl_jax/srt/kernels/fused_moe/v1/kernel.py b/python/sgl_jax/srt/kernels/fused_moe/v1/kernel.py new file mode 100644 index 000000000..b85b8d52e --- /dev/null +++ b/python/sgl_jax/srt/kernels/fused_moe/v1/kernel.py @@ -0,0 +1,1560 @@ +# Adapted from https://github.com/vllm-project/tpu-inference/blob/main/tpu_inference/kernels/fused_moe/v1/kernel.py +# Copyright 2025 The tpu-inference Authors. All rights reserved. +"""TPU-Friendly Fused Mixture of Experts (MoE) kernel.""" + +import functools + +import jax +import jax.numpy as jnp +from jax import lax +from jax._src import dtypes +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + +P = jax.sharding.PartitionSpec + +cdiv = pl.cdiv + + +def align_to(x, a): + return cdiv(x, a) * a + + +def get_dtype_packing(dtype): + bits = dtypes.bit_width(dtype) + return 32 // bits + + +def broadcast_minor(src, shape): + if src.shape == shape: + return src + assert src.shape[:-1] == shape[:-1] + assert src.shape[-1] % 128 == 0 + target_minor = align_to(shape[-1], src.shape[-1]) + # no-op concatenation. + return jnp.concatenate([src for _ in range(target_minor // src.shape[-1])], axis=-1)[ + ..., : shape[-1] + ] + + +def swigluoai( + gate: jax.Array, up: jax.Array, *, alpha: float = 1.702, limit: float = 7.0 +) -> jax.Array: + """Activation used in some models such as GPT-OSS.""" + gate = jnp.clip(gate, a_max=limit) + up = jnp.clip(up, a_min=-limit, a_max=limit) + glu = gate * jax.nn.sigmoid(alpha * gate) + return (up + 1.0) * glu + + +def activation_fn(acc1, acc3, act_fn): + if act_fn == "silu": + return jax.nn.silu(acc1) * acc3 + elif act_fn == "gelu": + return jax.nn.gelu(acc1) * acc3 + elif act_fn == "swigluoai": + return swigluoai(acc1, acc3) + else: + raise RuntimeError(f"Unsupported activation function: {act_fn}") + + +def ref_moe( + tokens: jax.Array, # (num_tokens, hidden_size) + w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size) + w2: jax.Array, # (num_experts, intermediate_size, hidden_size) + gating_output: jax.Array, # (num_tokens, num_experts) + top_k: int, + *, + renormalize_topk_logits: bool = False, + act_fn: str = "silu", + subc_quant_wsz: int | None = None, + w1_scale: ( + jax.Array | None + ) = None, # F32(num_experts, 2, hidden_size //subc_quant_wsz, 1, intermediate_size) + w2_scale: ( + jax.Array | None + ) = None, # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size) + b1: jax.Array | None = None, # F32(num_experts, 2, 1, intermediate_size) + b2: jax.Array | None = None, # F32(num_experts, 1, hidden_size) +): + n_tokens = tokens.shape[0] # num_tokens + + # Compute gating scores for all experts + gating_logits = jax.nn.softmax(gating_output, axis=-1) # [num_tokens, n_experts] + + # Select top-k experts per token + top_k_logits, top_k_indices = lax.top_k( + gating_logits, top_k + ) # [num_tokens, top_k], [num_tokens, top_k] + + if renormalize_topk_logits: + top_k_logits = top_k_logits / jnp.sum(top_k_logits, axis=-1, keepdims=True) + + t_outputs = [] + hidden_size, intermediate_size = w1.shape[-2:] + + # Process each token individually + for i in range(n_tokens): + curr_token = 
jnp.expand_dims(tokens[i], axis=0) # [1, hidden_size] + assigned_expert_ids = top_k_indices[i] # [top_k] - indices of selected experts for token i + tok_expert_act = [] + + # Process each selected expert for the current token + for expert_id in assigned_expert_ids: + # Get expert weights + expert_w1 = w1[expert_id, 0].astype(jnp.float32) + expert_w3 = w1[expert_id, 1].astype(jnp.float32) + if w1_scale is not None: + expert_w1 *= jnp.repeat(w1_scale[expert_id, 0, :, 0], subc_quant_wsz, axis=0)[ + :hidden_size + ] + expert_w3 *= jnp.repeat(w1_scale[expert_id, 1, :, 0], subc_quant_wsz, axis=0)[ + :hidden_size + ] + expert_weight_1 = jnp.concat( + [expert_w1, expert_w3], axis=-1 + ) # [hidden_size, 2 * intermediate_size] + expert_weight_2 = w2[expert_id].astype(jnp.float32) # [intermediate_size, hidden_size] + if w2_scale is not None: + expert_weight_2 *= jnp.repeat(w2_scale[expert_id, :, 0], subc_quant_wsz, axis=0)[ + :intermediate_size + ] + + # First linear layer with SwiGLU activation + gmm_1_out = curr_token @ expert_weight_1 # [1, 2 * intermediate_size] + + # Split into gate and up projections for SwiGLU + gmm1_w1_proj, gmm1_w3_proj = jnp.split( + gmm_1_out, 2, axis=-1 + ) # [1, intermediate_size], [1, intermediate_size] + if b1 is not None: + gmm1_w1_proj += b1[expert_id : expert_id + 1, 0, 0] + gmm1_w3_proj += b1[expert_id : expert_id + 1, 1, 0] + + # Apply gated activation: activation(gate) * up + act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, act_fn) + + # Second linear layer (down projection) + gmm_2_out = act @ expert_weight_2 # [1, hidden_size] + if b2 is not None: + gmm_2_out += b2[expert_id : expert_id + 1, 0] + tok_expert_act.append(gmm_2_out) + + # Combine outputs from all selected experts + experts_act = jnp.concatenate(tok_expert_act, axis=0) # [top_k, hidden_size] + + # Weighted sum using top-k gating weights + top_k_weights = top_k_logits[i] # [top_k] + top_k_weights = jnp.expand_dims(top_k_weights, axis=1) # [top_k, 1] + weighted_output = jnp.sum( + experts_act * top_k_weights, axis=0, keepdims=True + ) # [1, hidden_size] + + t_outputs.append(weighted_output.astype(tokens.dtype)) + + return jnp.concatenate(t_outputs, axis=0) # [actual_num_tokens, hidden_size] + + +def _fused_ep_moe_kernel( + # Input + tokens_hbm, # (local_num_tokens, t_packing, hidden_size // t_packing) + w1_hbm, # (local_num_experts, 2, hidden_size, intermediate_size) + w2_hbm, # (local_num_experts, intermediate_size, hidden_size) + # TODO(jevinjiang): We choose F32 scale for easier slicing. The extra + # latency should be hidden in the pipeline overlapping. But is there a better + # way to do this? 
+ w1_scale_hbm, # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size) + w2_scale_hbm, # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size) + b1_hbm, # None | F32(local_num_experts, 2, 1, intermediate_size) + b2_hbm, # None | F32(local_num_experts, 1, hidden_size) + gating_hbm, # (local_num_tokens, padded_num_experts) + a2a_g_hbm, # (num_experts, bt, t_packing, hidden_size // t_packing) + # Output + output_hbm, # (local_num_tokens, hidden_size) + # Scratch + t2e_routing_x2_smem, # (2, bt, padded_top_k) + d2e_count_x2_smem, # (2, num_devices, 1, padded_num_experts) + expert_offsets_x2_smem, # (2, 2, padded_num_experts): for a2a_s and a2a_g + expert_starts_x2_smem, # (2, 1, padded_num_experts) + expert_sizes_x2_smem, # (2, 1, padded_num_experts) + a2a_s_sends_x2_smem, # (2,) + a2a_s_x2_vmem, # (2, bt * num_devices, t_packing, hidden_size // t_packing) + a2a_s_acc_x2_vmem, # (2, bt * num_devices, t_packing, hidden_size // t_packing) + ### Accumulation for gathered tokens: + a2a_g_acc_vmem, # (top_k, bt, t_packing, hidden_size // t_packing) + ### Expert weight double buffering: + b_gating_x2_vmem, # (2, bt, padded_num_experts) + b_output_x2_vmem, # (2, bt, hidden_size) + b_w1_x2_vmem, # (2, t_packing, bd1 // t_packing, bf) + b_w3_x2_vmem, # (2, t_packing, bd1 // t_packing, bf) + b_w2_x2_vmem, # (2, t_packing, bf, bd2 // t_packing) + b_w1_scale_x2_vmem, # None | (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf) + b_w3_scale_x2_vmem, # None | (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf) + b_w2_scale_x2_vmem, # None | (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing) + b_b1_x2_vmem, # None | (2, 1, bf) + b_b3_x2_vmem, # None | (2, 1, bf) + b_b2_x2_vmem, # None | (2, t_packing, 1, bd2 // t_packing) + b_acc_vmem, # F32(bt * num_devices, 1, bf * 2) + ### Semaphores: + local_sems, # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem] + send_sems, # (2,) + recv_sems, # (2,) + a2a_gather_sem, + a2a_acc_sem, + *, + top_k: int, + renormalize_topk_logits: bool, + ep_axis_name: str, + act_fn: str, + subc_quant_wsz: int | None = None, + # Kernel tuning params. + bt: int, # Block size of local_num_tokens. + bf: int, # Block size of intermediate_size. + bd1: int, # Block size of hidden_size in w1. + bd2: int, # Block size of hidden_size in w2. + btc: int, # Compute size of block tokens for active expert. + bfc: int, # Compute size of block intermediate_size. + bd1c: int, # Compute size of block hidden_size. + bd2c: int, # Compute size of block hidden_size. 
+): + my_id = lax.axis_index(ep_axis_name) + num_devices = lax.axis_size(ep_axis_name) + local_num_tokens = tokens_hbm.shape[0] + local_num_experts, intermediate_size, hidden_size = w2_hbm.shape + right_id = (my_id + 1) % num_devices + num_experts = a2a_g_hbm.shape[0] + padded_num_experts = d2e_count_x2_smem.shape[-1] + padded_top_k = t2e_routing_x2_smem.shape[-1] + assert padded_num_experts == align_to(num_experts, 128) + assert padded_top_k == align_to(top_k, 128) + + t_dtype = tokens_hbm.dtype + t_packing = get_dtype_packing(t_dtype) + t_bitwidth = 32 // t_packing + assert a2a_g_hbm.dtype == t_dtype + assert w1_hbm.dtype == w2_hbm.dtype + + assert bd1 % bd1c == 0 + assert bd2 % bd2c == 0 + assert bf % bfc == 0 + assert hidden_size % t_packing == 0 + assert bd1 % t_packing == 0 + assert bd2 % t_packing == 0 + assert bd1c % t_packing == 0 + assert bd2c % t_packing == 0 + + h_per_t_packing = hidden_size // t_packing + assert tokens_hbm.shape[-1] == h_per_t_packing + bd1_per_t_packing = bd1 // t_packing + bd2_per_t_packing = bd2 // t_packing + bd1c_per_t_packing = bd1c // t_packing + bd2c_per_t_packing = bd2c // t_packing + + if subc_quant_wsz is not None: + assert subc_quant_wsz % 256 == 0 + assert bd1c_per_t_packing == subc_quant_wsz + assert bfc == subc_quant_wsz + assert bd1 % subc_quant_wsz == 0 + assert bf % subc_quant_wsz == 0 + assert bd1_per_t_packing % subc_quant_wsz == 0 + assert h_per_t_packing % subc_quant_wsz == 0 + + num_bt = cdiv(local_num_tokens, bt) + num_bf = cdiv(intermediate_size, bf) + num_bd1 = cdiv(hidden_size, bd1) + num_bd2 = cdiv(hidden_size, bd2) + + def get_mesh_device_id(ep_rank): + dp_rank = jax.lax.axis_index("data") + return (dp_rank, ep_rank) + + def sync_barrier(): + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal( + barrier_sem, + device_id=get_mesh_device_id(right_id), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 1) + + def start_fetch_b_gating(bt_id, priority=0): + is_valid = jnp.logical_and(bt_id >= 0, bt_id < num_bt) + sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt) + bt_sem_id = (bt_id + 2) % 2 + b_gating_sem = local_sems.at[bt_sem_id, 0] + pltpu.make_async_copy( + src_ref=gating_hbm.at[pl.ds(bt_id * bt, sz)], + dst_ref=b_gating_x2_vmem.at[bt_sem_id, pl.ds(0, sz)], + sem=b_gating_sem, + ).start(priority=priority) + + def wait_fetch_b_gating(bt_id): + bt_sem_id = bt_id % 2 + b_gating_sem = local_sems.at[bt_sem_id, 0] + pltpu.make_async_copy( + src_ref=b_gating_x2_vmem.at[bt_sem_id], + dst_ref=b_gating_x2_vmem.at[bt_sem_id], + sem=b_gating_sem, + ).wait() + + def get_top_k(input, top_k, renormalize_topk_logits): + assert len(input.shape) == 2, input.shape + input = input.astype(jnp.float32) + padded_k_shape = (input.shape[0], padded_top_k) + top_k_logits_lst = [] + top_k_indices_lst = [] + t2e = jnp.zeros(input.shape, dtype=jnp.int32) + t2e_routing = jnp.zeros(padded_k_shape, dtype=jnp.int32) + iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1) + padded_k_iota = jax.lax.broadcasted_iota(jnp.int32, padded_k_shape, 1) + top_k_logits_sum = jnp.zeros(padded_k_shape, jnp.float32) + + for k_id in range(top_k): + # TODO(jevinjiang): return both top_k values and indices in Mosaic + top_k_logits = jnp.broadcast_to( + jnp.max(input[:, :num_experts], axis=1, keepdims=True), + padded_k_shape, + ).astype(input.dtype) + top_k_logits_lst.append(top_k_logits) + if renormalize_topk_logits: + top_k_logits_sum += top_k_logits + # TODO(jevinjiang): support bf16 argmax in Mosaic + top_k_indices = 
jnp.broadcast_to( + jnp.argmax(input[:, :num_experts], axis=1, keepdims=True), + padded_k_shape, + ) + top_k_indices_lst.append(top_k_indices) + t2e_routing = jnp.where(padded_k_iota == k_id, top_k_indices, t2e_routing) + mask = iota == broadcast_minor(top_k_indices, input.shape) + t2e += mask.astype(jnp.int32) + if k_id != top_k - 1: + input = jnp.where(mask, -jnp.inf, input) + + if renormalize_topk_logits: + for k_id in range(top_k): + top_k_logits_lst[k_id] /= top_k_logits_sum + + expert_sizes = jnp.sum(t2e, axis=0, keepdims=True) + expert_starts = jnp.zeros_like(expert_sizes) + return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts + + def all_reduce_metadata(bt_sem_id, t2e_routing, starts, sizes): + send_sem = send_sems.at[0] + recv_sem = recv_sems.at[0] + + # All-reduce to accumulate starts and sizes and transfer to SMEM. + def _all_reduce_metadata( + t2e_routing_vmem, + d2e_count_vmem, + offsets_vmem, + starts_vmem, + sizes_vmem, + ): + offsets_vmem[...] = jnp.zeros_like(offsets_vmem) + # TODO(jevinjiang): check how slow is VMEM -> SMEM. + offsets_copy = pltpu.async_copy( + src_ref=offsets_vmem, + dst_ref=expert_offsets_x2_smem.at[bt_sem_id], + sem=send_sem, + ) + t2e_routing_vmem[...] = t2e_routing + t2e_routing_copy = pltpu.async_copy( + src_ref=t2e_routing_vmem, + dst_ref=t2e_routing_x2_smem.at[bt_sem_id], + sem=send_sem, + ) + reduced_sizes = sizes + reduced_starts = starts + row_id = my_id + d2e_count_vmem[row_id] = sizes + for i in range(num_devices - 1): + sync_barrier() + # TODO(jevinjiang): we can use double buffering to improve AR if needed. + pltpu.async_remote_copy( + src_ref=d2e_count_vmem.at[row_id], + dst_ref=d2e_count_vmem.at[row_id], + send_sem=send_sem, + recv_sem=recv_sem, + device_id=get_mesh_device_id(right_id), + device_id_type=pltpu.DeviceIdType.MESH, + ).wait() + row_id = (row_id + num_devices - 1) % num_devices + new_sizes = d2e_count_vmem[row_id] + reduced_sizes += new_sizes + reduced_starts += lax.select(my_id > i, new_sizes, jnp.zeros_like(new_sizes)) + starts_vmem[...] = reduced_starts + sizes_vmem[...] = reduced_sizes + + starts_copy = pltpu.async_copy( + src_ref=starts_vmem, + dst_ref=expert_starts_x2_smem.at[bt_sem_id], + sem=send_sem, + ) + sizes_copy = pltpu.async_copy( + src_ref=sizes_vmem, + dst_ref=expert_sizes_x2_smem.at[bt_sem_id], + sem=send_sem, + ) + + # TODO(jevinjiang): if d2e_count is too big, we can store in HBM and fetch + # to SMEM partially. + d2e_count_copy = pltpu.async_copy( + src_ref=d2e_count_vmem, + dst_ref=d2e_count_x2_smem.at[bt_sem_id], + sem=send_sem, + ) + + t2e_routing_copy.wait() + d2e_count_copy.wait() + offsets_copy.wait() + starts_copy.wait() + sizes_copy.wait() + + pl.run_scoped( + _all_reduce_metadata, + pltpu.VMEM(t2e_routing_x2_smem.shape[1:], t2e_routing_x2_smem.dtype), + pltpu.VMEM(d2e_count_x2_smem.shape[1:], d2e_count_x2_smem.dtype), + pltpu.VMEM(expert_offsets_x2_smem.shape[1:], expert_offsets_x2_smem.dtype), + pltpu.VMEM(expert_starts_x2_smem.shape[1:], expert_starts_x2_smem.dtype), + pltpu.VMEM(expert_sizes_x2_smem.shape[1:], expert_sizes_x2_smem.dtype), + ) + + def start_a2a_scatter(bt_id, e_sem_id, local_e_id): + bt_sem_id = bt_id % 2 + + # Counting the number of remote sends from the current device. 
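+        # [Editorial annotation, not part of the original patch] For each (token, top-k expert)
+        # pair in this token block whose expert occupies local slot `local_e_id` on its owning
+        # device, the token is DMA-copied into the a2a scatter buffer of the device that owns
+        # that expert: a local copy when this device is the owner, a remote copy otherwise.
+        # `expert_offsets` tracks the next free slot per expert, and `send_sz` counts the
+        # outstanding remote sends that wait_a2a_scatter_send later waits on.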
+ send_sz = 0 + for bt_t_id in range(bt): + for k_id in range(top_k): + e_id = t2e_routing_x2_smem[bt_sem_id, bt_t_id, k_id] + is_active_expert = e_id % local_num_experts == local_e_id + recv_id = e_id // local_num_experts + offset = expert_offsets_x2_smem[bt_sem_id, 0, e_id] + sz = lax.select(is_active_expert, 1, 0) + is_local = recv_id == my_id + local_sz = lax.select(is_local, sz, 0) + remote_sz = lax.select(is_local, 0, sz) + send_sz += remote_sz + expert_offsets_x2_smem[bt_sem_id, 0, e_id] = offset + local_sz + remote_sz + start = expert_starts_x2_smem[bt_sem_id, 0, e_id] + offset + t_id = bt * bt_id + bt_t_id + # TODO(jevinjiang): compare the perf when using branches. + pltpu.make_async_copy( + src_ref=tokens_hbm.at[pl.ds(t_id, local_sz)], + dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(start, local_sz)], + sem=recv_sems.at[e_sem_id], + ).start() + pltpu.make_async_remote_copy( + src_ref=tokens_hbm.at[pl.ds(t_id, remote_sz)], + dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(start, remote_sz)], + send_sem=send_sems.at[e_sem_id], + recv_sem=recv_sems.at[e_sem_id], + device_id=get_mesh_device_id(recv_id), + device_id_type=pltpu.DeviceIdType.MESH, + ).start() + a2a_s_sends_x2_smem[e_sem_id] = send_sz + + def wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id): + bt_sem_id = bt_id % 2 + e_id = my_id * local_num_experts + local_e_id + sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id] + pltpu.make_async_copy( + src_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)], + dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)], + sem=recv_sems.at[e_sem_id], + ).wait() + + def wait_a2a_scatter_send(bt_id, e_sem_id, local_e_id): + del bt_id, local_e_id + sz = a2a_s_sends_x2_smem[e_sem_id] + pltpu.make_async_copy( + src_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)], + dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)], + sem=send_sems.at[e_sem_id], + ).wait() + + def start_a2a_gather(bt_id, e_sem_id, local_e_id): + my_e_id = my_id * local_num_experts + local_e_id + bt_sem_id = bt_id % 2 + start = 0 + for recv_id in range(num_devices): + sz = d2e_count_x2_smem[bt_sem_id, recv_id, 0, my_e_id] + is_local = recv_id == my_id + local_sz = lax.select(is_local, sz, 0) + remote_sz = lax.select(is_local, 0, sz) + pltpu.make_async_copy( + src_ref=a2a_s_acc_x2_vmem.at[e_sem_id, pl.ds(start, local_sz)], + dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, local_sz)], + sem=a2a_gather_sem, + ).start() + pltpu.make_async_remote_copy( + src_ref=a2a_s_acc_x2_vmem.at[e_sem_id, pl.ds(start, remote_sz)], + dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)], + send_sem=send_sems.at[e_sem_id], + recv_sem=a2a_gather_sem, + device_id=get_mesh_device_id(recv_id), + device_id_type=pltpu.DeviceIdType.MESH, + ).start() + start += sz + + def wait_a2a_gather_send(bt_id, e_sem_id, local_e_id): + my_e_id = my_id * local_num_experts + local_e_id + bt_sem_id = bt_id % 2 + sz = expert_sizes_x2_smem[bt_sem_id, 0, my_e_id] + local_sz = d2e_count_x2_smem[bt_sem_id, my_id, 0, my_e_id] + remote_sz = sz - local_sz + is_valid = jnp.logical_and(local_e_id >= 0, local_e_id < local_num_experts) + remote_sz = lax.select(is_valid, remote_sz, 0) + pltpu.make_async_copy( + src_ref=a2a_g_hbm.at[0, pl.ds(0, remote_sz)], + dst_ref=a2a_g_hbm.at[0, pl.ds(0, remote_sz)], + sem=send_sems.at[e_sem_id], + ).wait() + + def wait_a2a_gather_recv_all(): + sz = top_k * bt + pltpu.make_async_copy( + src_ref=a2a_g_hbm.at[0, pl.ds(0, sz)], + dst_ref=a2a_g_hbm.at[0, pl.ds(0, sz)], + sem=a2a_gather_sem, + ).wait() + + def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id): + for p in range(t_packing): + 
offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing + pltpu.make_async_copy( + src_ref=w1_hbm.at[ + local_e_id, + 0, + pl.ds(offset, bd1_per_t_packing), + pl.ds(bf_id * bf, bf), + ], + dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p], + sem=local_sems.at[bw1_sem_id, 1], + ).start() + if w1_scale_hbm is not None: + assert subc_quant_wsz is not None + pltpu.make_async_copy( + src_ref=w1_scale_hbm.at[ + local_e_id, + 0, + pl.ds( + offset // subc_quant_wsz, + bd1_per_t_packing // subc_quant_wsz, + ), + pl.ds(0, 1), + pl.ds(bf_id * bf, bf), + ], + dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p], + sem=local_sems.at[bw1_sem_id, 1], + ).start() + if b1_hbm is not None and bd1_id == 0: + pltpu.make_async_copy( + src_ref=b1_hbm.at[local_e_id, 0, pl.ds(0, 1), pl.ds(bf_id * bf, bf)], + dst_ref=b_b1_x2_vmem.at[bf_id % 2], + sem=local_sems.at[bw1_sem_id, 1], + ).start() + + def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id): + for p in range(t_packing): + offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing + pltpu.make_async_copy( + src_ref=w2_hbm.at[ + local_e_id, + pl.ds(bf_id * bf, bf), + pl.ds(offset, bd2_per_t_packing), + ], + dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p], + sem=local_sems.at[bw2_sem_id, 2], + ).start() + if w2_scale_hbm is not None: + assert subc_quant_wsz is not None + pltpu.make_async_copy( + src_ref=w2_scale_hbm.at[ + local_e_id, + pl.ds(bf_id * bf // subc_quant_wsz, bf // subc_quant_wsz), + pl.ds(0, 1), + pl.ds(offset, bd2_per_t_packing), + ], + dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p], + sem=local_sems.at[bw2_sem_id, 2], + ).start() + if b2_hbm is not None and bf_id == 0: + pltpu.make_async_copy( + src_ref=b2_hbm.at[local_e_id, pl.ds(0, 1), pl.ds(offset, bd2_per_t_packing)], + dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p], + sem=local_sems.at[bw2_sem_id, 2], + ).start() + + def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id): + for p in range(t_packing): + offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing + pltpu.make_async_copy( + src_ref=w1_hbm.at[ + local_e_id, + 1, + pl.ds(offset, bd1_per_t_packing), + pl.ds(bf_id * bf, bf), + ], + dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p], + sem=local_sems.at[bw3_sem_id, 3], + ).start() + if w1_scale_hbm is not None: + assert subc_quant_wsz is not None + pltpu.make_async_copy( + src_ref=w1_scale_hbm.at[ + local_e_id, + 1, + pl.ds( + offset // subc_quant_wsz, + bd1_per_t_packing // subc_quant_wsz, + ), + pl.ds(0, 1), + pl.ds(bf_id * bf, bf), + ], + dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p], + sem=local_sems.at[bw3_sem_id, 3], + ).start() + if b1_hbm is not None and bd3_id == 0: + pltpu.make_async_copy( + src_ref=b1_hbm.at[local_e_id, 1, pl.ds(0, 1), pl.ds(bf_id * bf, bf)], + dst_ref=b_b3_x2_vmem.at[bf_id % 2], + sem=local_sems.at[bw3_sem_id, 3], + ).start() + + def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id): + del local_e_id + pltpu.make_async_copy( + src_ref=b_w1_x2_vmem.at[bw1_sem_id], + dst_ref=b_w1_x2_vmem.at[bw1_sem_id], + sem=local_sems.at[bw1_sem_id, 1], + ).wait() + if w1_scale_hbm is not None: + pltpu.make_async_copy( + src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id], + dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id], + sem=local_sems.at[bw1_sem_id, 1], + ).wait() + if b1_hbm is not None and bd1_id == 0: + pltpu.make_async_copy( + src_ref=b_b1_x2_vmem.at[bf_id % 2], + dst_ref=b_b1_x2_vmem.at[bf_id % 2], + sem=local_sems.at[bw1_sem_id, 1], + ).wait() + + def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id): + del local_e_id + pltpu.make_async_copy( + src_ref=b_w2_x2_vmem.at[bw2_sem_id], + 
dst_ref=b_w2_x2_vmem.at[bw2_sem_id], + sem=local_sems.at[bw2_sem_id, 2], + ).wait() + if w2_scale_hbm is not None: + pltpu.make_async_copy( + src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id], + dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id], + sem=local_sems.at[bw2_sem_id, 2], + ).wait() + if b2_hbm is not None and bf_id == 0: + pltpu.make_async_copy( + src_ref=b_b2_x2_vmem.at[bd2_id % 2], + dst_ref=b_b2_x2_vmem.at[bd2_id % 2], + sem=local_sems.at[bw2_sem_id, 2], + ).wait() + + def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id): + del local_e_id + pltpu.make_async_copy( + src_ref=b_w3_x2_vmem.at[bw3_sem_id], + dst_ref=b_w3_x2_vmem.at[bw3_sem_id], + sem=local_sems.at[bw3_sem_id, 3], + ).wait() + if w1_scale_hbm is not None: + pltpu.make_async_copy( + src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id], + dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id], + sem=local_sems.at[bw3_sem_id, 3], + ).wait() + if b1_hbm is not None and bd3_id == 0: + pltpu.make_async_copy( + src_ref=b_b3_x2_vmem.at[bf_id % 2], + dst_ref=b_b3_x2_vmem.at[bf_id % 2], + sem=local_sems.at[bw3_sem_id, 3], + ).wait() + + def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id): + next_bd1_id = bd1_id + 1 + next_bd2_id = bd2_id + 1 + next_sem_id = (bw_sem_id + 1) % 2 + + if bf_id >= num_bf: + return + if next_bd1_id < num_bd1: + start_fetch_bw1(local_e_id, next_sem_id, bf_id, next_bd1_id) + start_fetch_bw3(local_e_id, next_sem_id, bf_id, next_bd1_id) + elif next_bd1_id == num_bd1: + start_fetch_bw2(local_e_id, next_sem_id, bf_id, 0) + elif next_bd2_id < num_bd2: + start_fetch_bw2(local_e_id, next_sem_id, bf_id, next_bd2_id) + elif next_bd2_id == num_bd2: + start_fetch_next_bw(local_e_id, bw_sem_id, bf_id + 1, -1, -1) + else: + raise RuntimeError("Unreachable") + + def dynamic_ffn1( + t_b32_vmem, + w1_vmem, + w1_scale_vmem, + b1_vmem, + w3_vmem, + w3_scale_vmem, + b3_vmem, + acc1_vmem, + acc3_vmem, + dyn_sz, + should_init, + ): + assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing) + assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing, bf) + assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf) + assert bd1 % (t_packing * 128) == 0, (bd1, t_packing) + assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing) + if w1_scale_vmem is not None: + assert w1_scale_vmem.shape == ( + t_packing, + bd1_per_t_packing // subc_quant_wsz, + 1, + bf, + ) + assert bd1c_per_t_packing == subc_quant_wsz + if w3_scale_vmem is not None: + assert w3_scale_vmem.shape == ( + t_packing, + bd1_per_t_packing // subc_quant_wsz, + 1, + bf, + ) + assert bd1c_per_t_packing == subc_quant_wsz + + num_loops = cdiv(dyn_sz, btc) + repack_ty = jnp.dtype(f"int{t_bitwidth}") + + def body(btc_id, _): + for bd1c_id in range(cdiv(bd1, bd1c)): + t_b32 = t_b32_vmem[ + pl.ds(btc_id * btc, btc), + pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing), + ] + for p_id in range(t_packing): + t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype) + t_b32 = t_b32 >> t_bitwidth + for bfc_id in range(cdiv(bf, bfc)): + w_slices = ( + p_id, + pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing), + pl.ds(bfc_id * bfc, bfc), + ) + w1 = w1_vmem[*w_slices] + acc1 = jnp.dot(t, w1, preferred_element_type=jnp.float32) + + if w1_scale_vmem is not None: + w1_scale_slices = ( + p_id, + bd1c_id, + pl.ds(0, 1), + pl.ds(bfc_id * bfc, bfc), + ) + # TODO(jevinjiang): can use mosaic to load with stride 0. 
+ w1_scale = jnp.broadcast_to(w1_scale_vmem[*w1_scale_slices], acc1.shape) + acc1 *= w1_scale + + w3 = w3_vmem[*w_slices] + + acc3 = jnp.dot(t, w3, preferred_element_type=jnp.float32) + + if w3_scale_vmem is not None: + w3_scale_slices = ( + p_id, + bd1c_id, + pl.ds(0, 1), + pl.ds(bfc_id * bfc, bfc), + ) + w3_scale = jnp.broadcast_to(w3_scale_vmem[*w3_scale_slices], acc3.shape) + acc3 *= w3_scale + + acc_slices = (pl.ds(btc_id * btc, btc), pl.ds(bfc_id * bfc, bfc)) + if should_init and p_id == bd1c_id == 0: + if b1_vmem is not None: + b1_scale_slices = ( + pl.ds(0, 1), + pl.ds(bfc_id * bfc, bfc), + ) + b1 = jnp.broadcast_to(b1_vmem[*b1_scale_slices], acc1.shape) + acc1 += b1 + if b3_vmem is not None: + b3_scale_slices = ( + pl.ds(0, 1), + pl.ds(bfc_id * bfc, bfc), + ) + b3 = jnp.broadcast_to(b3_vmem[*b3_scale_slices], acc1.shape) + acc3 += b3 + + acc1_vmem[*acc_slices] = acc1 + acc3_vmem[*acc_slices] = acc3 + else: + acc1_vmem[*acc_slices] += acc1 + acc3_vmem[*acc_slices] += acc3 + + lax.fori_loop(0, num_loops, body, None) + + def dynamic_ffn2( + acc1_vmem, + acc3_vmem, + w2_vmem, + w2_scale_vmem, + b2_vmem, + res_b32_vmem, + dyn_sz, + should_init, + ): + assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing) + assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing) + assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf) + assert bd2 % (t_packing * 128) == 0, (bd2, t_packing) + assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing) + assert t_dtype in (jnp.float32, jnp.bfloat16) + + if w2_scale_vmem is not None: + assert w2_scale_vmem.shape == ( + t_packing, + bf // subc_quant_wsz, + 1, + bd2_per_t_packing, + ) + assert bfc == subc_quant_wsz + + num_loops = cdiv(dyn_sz, btc) + assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing) + + def body(btc_id, _): + for bd2c_id in range(cdiv(bd2, bd2c)): + res_lst = [] + for p_id in range(t_packing): + res = jnp.zeros((btc, bd2c_per_t_packing), dtype=jnp.float32) + + if b2_vmem is not None and should_init: + b2_scale_slices = ( + p_id, + pl.ds(0, 1), + pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing), + ) + b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices], res.shape) + res += b2 + + for bfc_id in range(cdiv(bf, bfc)): + acc_slices = (pl.ds(btc_id * btc, btc), pl.ds(bfc_id * bfc, bfc)) + acc1 = acc1_vmem[*acc_slices] + acc3 = acc3_vmem[*acc_slices] + act = activation_fn(acc1, acc3, act_fn) + w2 = w2_vmem[ + p_id, + pl.ds(bfc_id * bfc, bfc), + pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing), + ] + acc = jnp.dot(act, w2, preferred_element_type=jnp.float32) + if w2_scale_vmem is not None: + w2_scale_slices = ( + p_id, + bfc_id, + pl.ds(0, 1), + pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing), + ) + w2_scale = jnp.broadcast_to(w2_scale_vmem[*w2_scale_slices], acc.shape) + acc *= w2_scale + res += acc + res = pltpu.bitcast(res, jnp.uint32) + if t_packing == 2: + res = res >> 16 << (16 * p_id) + else: + assert t_packing == 1 + res_lst.append(res) + res = res_lst[0] + # TODO(jevinjiang): use interleaved packing when it is exposed to Pallas + for i in range(1, t_packing): + res |= res_lst[i] + sliced_res_vmem = res_b32_vmem.at[ + pl.ds(btc_id * btc, btc), + pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing), + ] + if should_init: + sliced_res_vmem[...] = res + else: + sliced_res_vmem[...] = pltpu.bitcast( + sliced_res_vmem.bitcast(t_dtype)[...] 
+ pltpu.bitcast(res, t_dtype), + sliced_res_vmem.dtype, + ) + + lax.fori_loop(0, num_loops, body, None) + + def expert_ffn(bt_id, e_sem_id, local_e_id): + bt_sem_id = bt_id % 2 + bw_sem_id = 0 + # start_fetch_bw1(local_e_id, bw_sem_id, 0, 0) + # start_fetch_bw3(local_e_id, bw_sem_id, 0, 0) + a2a_s_b32_vmem = ( + a2a_s_x2_vmem.bitcast(jnp.uint32) + .reshape(2, bt * num_devices, hidden_size // t_packing) + .at[e_sem_id] + ) + a2a_s_acc_b32_vmem = ( + a2a_s_acc_x2_vmem.bitcast(jnp.uint32) + .reshape(2, bt * num_devices, hidden_size // t_packing) + .at[e_sem_id] + ) + b_acc_vmem_2d = b_acc_vmem.reshape(bt * num_devices, bf * 2) + b_acc1_vmem = b_acc_vmem_2d.at[:, :bf] + b_acc3_vmem = b_acc_vmem_2d.at[:, bf:] + + e_id = my_id * local_num_experts + local_e_id + dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id] + + bd1_per_t_packing = bd1 // t_packing + bd2_per_t_packing = bd2 // t_packing + + for bf_id in range(num_bf): + for bd1_id in range(num_bd1): + start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0) + w1_scale_vmem = ( + None if b_w1_scale_x2_vmem is None else b_w1_scale_x2_vmem.at[bw_sem_id] + ) + w3_scale_vmem = ( + None if b_w3_scale_x2_vmem is None else b_w3_scale_x2_vmem.at[bw_sem_id] + ) + b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[bf_id % 2] + b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[bf_id % 2] + wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id) + wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id) + + dynamic_ffn1( + t_b32_vmem=a2a_s_b32_vmem.at[ + ..., pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing) + ], + w1_vmem=b_w1_x2_vmem.at[bw_sem_id], + w1_scale_vmem=w1_scale_vmem, + b1_vmem=b1_vmem, + w3_vmem=b_w3_x2_vmem.at[bw_sem_id], + w3_scale_vmem=w3_scale_vmem, + b3_vmem=b3_vmem, + acc1_vmem=b_acc1_vmem, + acc3_vmem=b_acc3_vmem, + dyn_sz=dyn_sz, + should_init=(bd1_id == 0), + ) + bw_sem_id = (bw_sem_id + 1) % 2 + + for bd2_id in range(num_bd2): + start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, num_bd1, bd2_id) + wait_fetch_bw2(local_e_id, bw_sem_id, bf_id, bd2_id) + if bf_id == bd2_id == 0: + wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2) + + w2_scale_vmem = ( + None if b_w2_scale_x2_vmem is None else b_w2_scale_x2_vmem.at[bw_sem_id] + ) + b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[bd2_id % 2] + dynamic_ffn2( + acc1_vmem=b_acc1_vmem, + acc3_vmem=b_acc3_vmem, + w2_vmem=b_w2_x2_vmem.at[bw_sem_id], + w2_scale_vmem=w2_scale_vmem, + b2_vmem=b2_vmem, + res_b32_vmem=a2a_s_acc_b32_vmem.at[ + ..., pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing) + ], + dyn_sz=dyn_sz, + should_init=(bf_id == 0), + ) + bw_sem_id = (bw_sem_id + 1) % 2 + + def bt_acc(bt_id, top_k_logits_lst): + bt_sem_id = bt_id % 2 + for bt_t_id in range(bt): + for k_id in range(top_k): + e_id = t2e_routing_x2_smem[bt_sem_id, bt_t_id, k_id] + offset = expert_offsets_x2_smem[bt_sem_id, 1, e_id] + expert_offsets_x2_smem[bt_sem_id, 1, e_id] = offset + 1 + pltpu.make_async_copy( + src_ref=a2a_g_hbm.at[e_id, pl.ds(offset, 1)], + dst_ref=a2a_g_acc_vmem.at[k_id, pl.ds(bt_t_id, 1)], + sem=a2a_acc_sem, + ).start() + pltpu.make_async_copy( + src_ref=a2a_g_acc_vmem, + dst_ref=a2a_g_acc_vmem, + sem=a2a_acc_sem, + ).wait() + output = None + for k_id in range(top_k): + acc = a2a_g_acc_vmem[k_id].reshape(bt, hidden_size) + logits = broadcast_minor(top_k_logits_lst[k_id], acc.shape) + acc *= logits + if output is None: + output = acc + else: + output += acc + assert output is not None + return output.astype(output_hbm.dtype) + + def start_send_bo(bt_id, 
priority=0): + bt_sem_id = bt_id % 2 + b_output_sem = local_sems.at[bt_sem_id, 4] + pltpu.make_async_copy( + src_ref=b_output_x2_vmem.at[bt_sem_id], + dst_ref=output_hbm.at[pl.ds(bt_id * bt, bt)], + sem=b_output_sem, + ).start(priority=priority) + + def wait_send_bo(bt_id): + is_valid = jnp.logical_and(bt_id >= 0, bt_id < num_bt) + sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt) + bt_sem_id = (bt_id + 2) % 2 + b_output_sem = local_sems.at[bt_sem_id, 4] + pltpu.make_async_copy( + src_ref=output_hbm.at[pl.ds(0, sz)], + dst_ref=output_hbm.at[pl.ds(0, sz)], + sem=b_output_sem, + ).wait() + + ### ------- Kernel start ------- ### + start_fetch_b_gating(bt_id=0) + + def run_per_bt(bt_id, e_sem_id): + bt_sem_id = bt_id % 2 + next_bt_id = bt_id + 1 + start_fetch_b_gating(next_bt_id) + wait_fetch_b_gating(bt_id) + + b_gating = b_gating_x2_vmem[bt_sem_id] + b_gating_score = jax.nn.softmax(b_gating, axis=-1) + top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k( + b_gating_score, top_k, renormalize_topk_logits + ) + + all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts, expert_sizes) + sync_barrier() + + # Start a2a scatter for first active expert. + start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0) + + def run_per_expert(local_e_id, e_sem_id): + sync_barrier() + + # Prefetch weights for CURRENT active expert. + # TODO(jevinjiang): It is hard to prefetch weights in previous iteration + # because the expert_ffn keeps overwriting the buffers. Triple buffering + # could resolve this but it takes more VMEM scratch. Need further + # experiment on this. + start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0) + start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0) + + # Next ids. + next_e_sem_id = lax.select(e_sem_id == 0, 1, 0) + next_local_e_id = local_e_id + 1 + + # Start a2a scatter for NEXT active expert. + @pl.when(next_local_e_id < local_num_experts) + def _(): + start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id) + + # Wait a2a scatter for CURRENT active expert. + wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id) + + # Perform FFN for CURRENT active expert. + expert_ffn(bt_id, e_sem_id, local_e_id) + + # Start a2a gather to send back tokens for CURRENT active expert. + start_a2a_gather(bt_id, e_sem_id, local_e_id) + + # A must-wait before next sync_barrier. + wait_a2a_scatter_send(bt_id, e_sem_id, local_e_id) + return next_e_sem_id + + e_sem_id = lax.fori_loop(0, local_num_experts, run_per_expert, e_sem_id, unroll=False) + + # Wait to receive a2a gather for ALL experts. + wait_a2a_gather_recv_all() + + # Accumulate results for current batch. + output = bt_acc(bt_id, top_k_logits_lst) + + # Make sure it is safe to overwrite output buffer. 
+ wait_send_bo(bt_id=bt_id - 2) + b_output_x2_vmem[bt_sem_id] = output + + start_send_bo(bt_id) + + wait_a2a_gather_send( + bt_id, + e_sem_id=e_sem_id, + local_e_id=local_num_experts - 2, + ) + wait_a2a_gather_send( + bt_id, + e_sem_id=lax.select(e_sem_id == 0, 1, 0), + local_e_id=local_num_experts - 1, + ) + return e_sem_id + + lax.fori_loop(0, num_bt, run_per_bt, 0, unroll=False) + wait_send_bo(bt_id=num_bt - 2) + wait_send_bo(bt_id=num_bt - 1) + + ### ------- Kernel end ------- ### + + +@functools.partial( + jax.jit, + static_argnames=[ + "mesh", + "top_k", + "renormalize_topk_logits", + "act_fn", + "subc_quant_wsz", + "bt", + "bf", + "bd1", + "bd2", + "btc", + "bfc", + "bd1c", + "bd2c", + "ep_axis_name", + ], +) +def fused_ep_moe( + mesh: jax.sharding.Mesh, + tokens: jax.Array, # (num_tokens, hidden_size) + w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size) + w2: jax.Array, # (num_experts, intermediate_size, hidden_size) + gating_output: jax.Array, # (num_tokens, num_experts) + top_k: int, + *, + renormalize_topk_logits: bool = False, + act_fn: str = "silu", + subc_quant_wsz: int | None = None, + w1_scale: ( + jax.Array | None + ) = None, # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size) + w2_scale: ( + jax.Array | None + ) = None, # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size) + b1: jax.Array | None = None, # F32(num_experts, 2, 1, intermediate_size) + b2: jax.Array | None = None, # F32(num_experts, 1, hidden_size) + # Kernel tuning parameters. + bt: int, + bf: int, + bd1: int, + bd2: int, + btc: int, + bfc: int, + bd1c: int, + bd2c: int, + ep_axis_name: str = "tensor", +): + # TODO(jevinjiang): move all these assertions to validation function. + if len(mesh.shape) != 2: + raise NotImplementedError("Only 2D mesh is supported.") + + for axis_name in mesh.axis_names: + if axis_name == ep_axis_name: + continue + if mesh.shape[axis_name] != 1: + raise NotImplementedError(f"Expected all non-ep axis to have size 1 in {mesh.shape=}") + + ep_size = mesh.shape[ep_axis_name] + num_devices = ep_size + + num_tokens, hidden_size = tokens.shape + num_experts, intermediate_size, _ = w2.shape + + if w1.shape != (num_experts, 2, hidden_size, intermediate_size): + raise ValueError( + f"Expected {w1.shape=} to be" f" {(num_experts, 2, hidden_size, intermediate_size)}." + ) + + if w2.shape != (num_experts, intermediate_size, hidden_size): + raise ValueError( + f"Expected {w2.shape=} to be" f" {(num_experts, intermediate_size, hidden_size)}." + ) + + if gating_output.shape != (num_tokens, num_experts): + raise ValueError(f"Expected {gating_output.shape=} to be {(num_tokens, num_experts)}.") + + if not (0 < top_k <= num_experts): + raise ValueError(f"Expected {top_k=} to be in range (0, {num_experts=}].") + + if hidden_size % 128 != 0 or intermediate_size % 128 != 0: + raise ValueError( + f"Expected {hidden_size=} and {intermediate_size=} to be aligned to" + " 128. Did you pad them with zeros outside the kernel?" 
+ ) + if num_tokens % ep_size != 0: + raise ValueError(f"Expected {num_tokens=} to be aligned to {ep_size=}.") + if num_experts % ep_size != 0: + raise ValueError(f"Expected {num_experts=} to be aligned to {ep_size=}.") + + local_num_tokens = num_tokens // ep_size + # local_num_experts = num_experts // ep_size + padded_num_experts = align_to(num_experts, 128) + padded_top_k = align_to(top_k, 128) + t_dtype = tokens.dtype + t_packing = get_dtype_packing(t_dtype) + + # Override bt + if local_num_tokens <= t_packing * 8: + bt = local_num_tokens + btc = bt + bt = min(local_num_tokens, bt) + # The worst case is that all devices send bt to one device. + btc = min(bt, btc, bt * num_devices) + + if local_num_tokens % t_packing != 0: + raise ValueError(f"Expected {local_num_tokens=} to be aligned to {t_packing=}.") + + if bt % t_packing != 0: + raise ValueError(f"Expected {bt=} to be aligned to {t_packing=}.") + if local_num_tokens % bt != 0: + raise ValueError(f"Expected {local_num_tokens=} to be aligned to {bt=}.") + + if subc_quant_wsz is not None: + if subc_quant_wsz <= 0: + raise ValueError(f"Expected {subc_quant_wsz=} to be non-negative.") + if subc_quant_wsz % 256 != 0: + raise ValueError("Expected {subc_quant_wsz=} to be aligned to 256.") + if hidden_size % subc_quant_wsz != 0: + raise ValueError(f"Expected {hidden_size=} to be aligned to {subc_quant_wsz=}.") + if intermediate_size % subc_quant_wsz != 0: + raise ValueError(f"Expected {intermediate_size=} to be aligned to {subc_quant_wsz=}.") + # We force compute size of contracting dim to be subc_quant_wsz. So we can + # apply same scale after matmul and accumulation. + bd1c = subc_quant_wsz * t_packing + bfc = subc_quant_wsz + + if bfc % 128 != 0: + raise ValueError(f"Expected {bfc=} to be aligned to 128.") + if bd1c % (t_packing * 128) != 0: + raise ValueError(f"Expected {bd1c=} to be aligned to {t_packing * 128}.") + if bd2c % (t_packing * 128) != 0: + raise ValueError(f"Expected {bd2c=} to be aligned to {t_packing * 128}.") + if bf % bfc != 0: + raise ValueError(f"Expected {bf=} to be aligned to {bfc=}.") + if bd1 % bd1c != 0: + raise ValueError(f"Expected {bd1=} to be aligned to {bd1c=}.") + if bd2 % bd2c != 0: + raise ValueError(f"Expected {bd2=} to be aligned to {bd2c=}.") + if hidden_size % bd1 != 0 or hidden_size % bd2 != 0: + raise ValueError(f"Expected {hidden_size=} to be aligned to {bd1=} and {bd2=}.") + if intermediate_size % bf != 0: + raise ValueError(f"Expected {intermediate_size=} to be aligned to {bf=}.") + + # Note: we should dump scale as the kernel expected shape in the + # checkpoint offline or reshape right after weight loading. 
+ if w1_scale is not None: + expected_w1_scale_shape = ( + num_experts, + 2, + hidden_size // subc_quant_wsz, + 1, + intermediate_size, + ) + if w1_scale.shape != expected_w1_scale_shape: + raise ValueError(f"Expected {w1_scale.shape=} to be {expected_w1_scale_shape}.") + if w1_scale.dtype != jnp.float32: + w1_scale = w1_scale.astype(jnp.float32) + + if w2_scale is not None: + expected_w2_scale_shape = ( + num_experts, + intermediate_size // subc_quant_wsz, + 1, + hidden_size, + ) + if w2_scale.shape != expected_w2_scale_shape: + raise ValueError(f"Expected {w2_scale.shape=} to be {expected_w2_scale_shape}.") + if w2_scale.dtype != jnp.float32: + w2_scale = w2_scale.astype(jnp.float32) + + if b1 is not None: + expected_b1_shape = (num_experts, 2, 1, intermediate_size) + if b1.shape != expected_b1_shape: + raise ValueError(f"Expected {b1.shape=} to be {expected_b1_shape}.") + if b1.dtype != jnp.float32: + b1 = b1.astype(jnp.float32) + + if b2 is not None: + expected_b2_shape = (num_experts, 1, hidden_size) + if b2.shape != expected_b2_shape: + raise ValueError(f"Expected {b2.shape=} to be {expected_b2_shape}.") + if b2.dtype != jnp.float32: + b2 = b2.astype(jnp.float32) + + # Prepare inputs for the kernel. + if padded_num_experts != gating_output.shape[-1]: + gating_output = jnp.pad( + gating_output, + ((0, 0), (0, padded_num_experts - gating_output.shape[-1])), + constant_values=-jnp.inf, + ) + + tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing) + + hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM) + renorm_str = "-renorm_k" if renormalize_topk_logits else "" + scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}" + fused_moe = jax.named_scope(scope_name)( + pl.pallas_call( + functools.partial( + _fused_ep_moe_kernel, + top_k=top_k, + renormalize_topk_logits=renormalize_topk_logits, + ep_axis_name=ep_axis_name, + act_fn=act_fn, + subc_quant_wsz=subc_quant_wsz, + bt=bt, + bf=bf, + bd1=bd1, + bd2=bd2, + btc=btc, + bfc=bfc, + bd1c=bd1c, + bd2c=bd2c, + ), + out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size), t_dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + hbm_block_spec, # tokens_hbm + hbm_block_spec, # w1_hbm + hbm_block_spec, # w2_hbm + None if w1_scale is None else hbm_block_spec, # w1_scale_hbm + None if w2_scale is None else hbm_block_spec, # w2_scale_hbm + None if b1 is None else hbm_block_spec, # b1_hbm + None if b2 is None else hbm_block_spec, # b2_hbm + hbm_block_spec, # gating_output_hbm + hbm_block_spec, # a2a_g_hbm + ], + out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM), + scratch_shapes=( + [ + # t2e_routing_x2_smem + pltpu.SMEM((2, bt, padded_top_k), jnp.int32), + # d2e_count_x2_smem + pltpu.SMEM((2, num_devices, 1, padded_num_experts), jnp.int32), + # expert_offsets_x2_smem + pltpu.SMEM((2, 2, padded_num_experts), jnp.int32), + # expert_starts_x2_smem + pltpu.SMEM((2, 1, padded_num_experts), jnp.int32), + # expert_sizes_x2_smem + pltpu.SMEM((2, 1, padded_num_experts), jnp.int32), + # a2a_s_sends_x2_smem + pltpu.SMEM((2,), jnp.int32), + # a2a_s_x2_vmem + pltpu.VMEM( + ( + 2, + bt * num_devices, + t_packing, + hidden_size // t_packing, + ), + t_dtype, + ), + # a2a_s_acc_x2_vmem + pltpu.VMEM( + ( + 2, + bt * num_devices, + t_packing, + hidden_size // t_packing, + ), + t_dtype, + ), + # a2a_g_acc_vmem + pltpu.VMEM((top_k, bt, t_packing, hidden_size // t_packing), t_dtype), + # b_gating_x2_vmem + pltpu.VMEM((2, bt, 
padded_num_experts), t_dtype), + # b_output_x2_vmem + pltpu.VMEM((2, bt, hidden_size), t_dtype), + # b_w1_x2_vmem + pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype), + # b_w3_x2_vmem + pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype), + # b_w2_x2_vmem + pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype), + # b_w1_scale_x2_vmem + ( + None + if w1_scale is None + else pltpu.VMEM( + ( + 2, + t_packing, + bd1 // t_packing // subc_quant_wsz, + 1, + bf, + ), + jnp.float32, + ) + ), + # b_w3_scale_x2_vmem + ( + None + if w1_scale is None + else pltpu.VMEM( + ( + 2, + t_packing, + bd1 // t_packing // subc_quant_wsz, + 1, + bf, + ), + jnp.float32, + ) + ), + # b_w2_scale_x2_vmem + ( + None + if w2_scale is None + else pltpu.VMEM( + ( + 2, + t_packing, + bf // subc_quant_wsz, + 1, + bd2 // t_packing, + ), + jnp.float32, + ) + ), + # b_b1_x2_vmem + ( + None + if b1 is None + else pltpu.VMEM( + ( + 2, + 1, + bf, + ), + jnp.float32, + ) + ), + # b_b3_x2_vmem + ( + None + if b1 is None + else pltpu.VMEM( + ( + 2, + 1, + bf, + ), + jnp.float32, + ) + ), + # b_b2_x2_vmem + ( + None + if b2 is None + else pltpu.VMEM( + ( + 2, + t_packing, + 1, + bd2 // t_packing, + ), + jnp.float32, + ) + ), + # b_acc_vmem + pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32), + # local_sems + pltpu.SemaphoreType.DMA((2, 5)), + # send_sems + pltpu.SemaphoreType.DMA((2,)), + # recv_sems + pltpu.SemaphoreType.DMA((2,)), + # a2a_gather_sem + pltpu.SemaphoreType.DMA, + # a2a_acc_sem + pltpu.SemaphoreType.DMA, + ] + ), + ), + compiler_params=pltpu.CompilerParams( + collective_id=0, + vmem_limit_bytes=100 * 1024 * 1024, + ), + name=scope_name, + ) + ) + + @jax.jit + @jax.shard_map( + mesh=mesh, + in_specs=( + P(ep_axis_name), # tokens_hbm + P(ep_axis_name), # w1_hbm + P(ep_axis_name), # w2_hbm + None if w1_scale is None else P(ep_axis_name), # w1_scale_hbm + None if w2_scale is None else P(ep_axis_name), # w2_scale_hbm + None if b1 is None else P(ep_axis_name), # b1_hbm + None if b2 is None else P(ep_axis_name), # b2_hbm + P(ep_axis_name), # gating_output_hbm + P(), # a2a_g_hbm + ), + out_specs=P(ep_axis_name), + check_vma=False, + ) + def kernel( + tokens, + w1, + w2, + w1_scale, + w2_scale, + b1, + b2, + gating_output, + a2a_g_hbm_scratch, + ): + return fused_moe( + pltpu.with_memory_space_constraint(tokens, pltpu.HBM), # tokens_hbm + pltpu.with_memory_space_constraint(w1, pltpu.HBM), # w1_hbm + pltpu.with_memory_space_constraint(w2, pltpu.HBM), # w2_hbm + ( + None + if w1_scale is None + else pltpu.with_memory_space_constraint(w1_scale, pltpu.HBM) + ), # w1_scale_hbm + ( + None + if w2_scale is None + else pltpu.with_memory_space_constraint(w2_scale, pltpu.HBM) + ), # w2_scale_hbm + (None if b1 is None else pltpu.with_memory_space_constraint(b1, pltpu.HBM)), # b1_hbm + (None if b2 is None else pltpu.with_memory_space_constraint(b2, pltpu.HBM)), # b2_hbm + pltpu.with_memory_space_constraint(gating_output, pltpu.HBM), # gating_output_hbm + pltpu.with_memory_space_constraint(a2a_g_hbm_scratch, pltpu.HBM), # a2a_g_hbm + ) + + a2a_g_hbm_scratch = pl.empty((num_experts, bt, t_packing, hidden_size // t_packing), t_dtype) + return kernel( + tokens, + w1, + w2, + w1_scale, + w2_scale, + b1, + b2, + gating_output, + a2a_g_hbm_scratch, + ) diff --git a/python/sgl_jax/srt/layers/fused_moe.py b/python/sgl_jax/srt/layers/fused_moe.py new file mode 100644 index 000000000..2e55233fb --- /dev/null +++ b/python/sgl_jax/srt/layers/fused_moe.py @@ -0,0 +1,219 @@ +"""Fused MoE layer using optimized TPU kernel.""" + 
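(Editor's note: a minimal usage sketch of the `fused_ep_moe` kernel that this layer wraps. The mesh layout, array shapes, and tile sizes below are illustrative assumptions chosen to satisfy the kernel's alignment checks; they are not values taken from this patch. See `_get_default_tile_sizes` below for the tile sizes the layer actually selects.)

```python
# Illustrative sketch (TPU only): shapes, mesh layout, and tile sizes are assumptions
# chosen to pass the alignment checks in fused_ep_moe, not values from the patch.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh

from sgl_jax.srt.kernels.fused_moe.v1.kernel import fused_ep_moe

devices = np.array(jax.devices()).reshape(1, -1)   # 2D mesh: "data" axis = 1, "tensor" axis = ep_size
mesh = Mesh(devices, ("data", "tensor"))

# num_tokens and num_experts must be divisible by the number of devices (ep_size);
# hidden and intermediate sizes must be multiples of 128.
num_tokens, hidden, inter, num_experts, top_k = 512, 1024, 4096, 8, 2
tokens = jnp.zeros((num_tokens, hidden), jnp.bfloat16)
w1 = jnp.zeros((num_experts, 2, hidden, inter), jnp.bfloat16)  # dim 1 stacks [gate_proj, up_proj]
w2 = jnp.zeros((num_experts, inter, hidden), jnp.bfloat16)
gating = jnp.zeros((num_tokens, num_experts), jnp.bfloat16)    # router logits

out = fused_ep_moe(
    mesh, tokens, w1, w2, gating, top_k,
    act_fn="silu",
    bt=32, bf=512, bd1=512, bd2=512,      # block sizes (the "small model" defaults below)
    btc=32, bfc=256, bd1c=256, bd2c=256,  # compute tile sizes
    ep_axis_name="tensor",
)
print(out.shape)  # (num_tokens, hidden)
```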
+import jax +import jax.numpy as jnp +from flax import nnx +from jax.sharding import Mesh +from jax.sharding import PartitionSpec as P + +from sgl_jax.srt.kernels.fused_moe.v1.kernel import fused_ep_moe + + +def _get_default_tile_sizes(hidden_size: int, intermediate_size: int) -> dict[str, int]: + """ + Select appropriate tile sizes based on model dimensions. + + These values are derived from benchmarking in the test suite and optimized + for TPU performance with different model sizes. + + Args: + hidden_size: Model hidden dimension + intermediate_size: MoE intermediate (FFN) dimension + + Returns: + Dictionary containing tile size parameters for the fused kernel + """ + if hidden_size >= 4096: + # Large models (e.g., Qwen 2.5B) + return { + "bt": 64, + "bf": 768, + "bd1": 2048, + "bd2": 2048, + "btc": 64, + "bfc": 768, + "bd1c": 2048, + "bd2c": 2048, + } + elif hidden_size >= 2048: + # Medium models (e.g., Qwen 30B A3B) + return { + "bt": 16, + "bf": 384, + "bd1": 512, + "bd2": 512, + "btc": 16, + "bfc": 384, + "bd1c": 256, + "bd2c": 256, + } + else: + # Small models + return { + "bt": 32, + "bf": 512, + "bd1": 512, + "bd2": 512, + "btc": 32, + "bfc": 256, + "bd1c": 256, + "bd2c": 256, + } + + +class FusedEPMoE(nnx.Module): + """ + Expert Parallel MoE layer using fused TPU kernel. + + This layer wraps the optimized fused_ep_moe kernel which combines Top-K selection, + expert computation, and aggregation into a single efficient operation. + + Key differences from EPMoE: + - Weight format: w1 is 4D (num_experts, 2, hidden_size, intermediate_size) + where dimension 2 contains [gate_proj, up_proj] + - Input: Takes router_logits directly instead of pre-computed topk_weights/topk_ids + - Implementation: Uses Pallas kernel with manual memory management for TPU optimization + + Args: + config: Model configuration + num_experts: Total number of experts + num_experts_per_tok: Number of experts to select per token (top_k) + ep_size: Expert parallel size (number of devices to shard experts across) + mesh: JAX mesh for distributed execution + intermediate_dim: Intermediate dimension for expert FFN + weight_dtype: Data type for weights + dtype: Data type for computation + activation: Activation function ("silu", "gelu", "swigluoai") + layer_id: Layer index (for debugging) + renormalize_topk_logits: Whether to renormalize top-k weights + bt, bf, bd1, bd2, btc, bfc, bd1c, bd2c: Tile size parameters (auto-selected if None) + """ + + def __init__( + self, + config, + num_experts: int, + num_experts_per_tok: int, + ep_size: int, + mesh: Mesh, + intermediate_dim: int = 2048, + weight_dtype: jnp.dtype = jnp.bfloat16, + dtype: jnp.dtype = jnp.bfloat16, + activation: str = "silu", + layer_id: int = 0, + renormalize_topk_logits: bool = False, + # Tile size parameters - auto-selected if None + bt: int | None = None, + bf: int | None = None, + bd1: int | None = None, + bd2: int | None = None, + btc: int | None = None, + bfc: int | None = None, + bd1c: int | None = None, + bd2c: int | None = None, + ): + self.config = config + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.intermediate_dim = intermediate_dim + self.weight_dtype = weight_dtype + self.dtype = dtype + self.layer_id = layer_id + self.ep_size = ep_size + self.activation = activation + self.renormalize_topk_logits = renormalize_topk_logits + self.mesh = mesh + + if num_experts % self.ep_size != 0: + raise ValueError( + f"num_experts({num_experts}) must be divisible by ep_size ({self.ep_size})" + ) + + # Auto-select 
tile sizes if not provided + if any(param is None for param in [bt, bf, bd1, bd2, btc, bfc, bd1c, bd2c]): + default_sizes = _get_default_tile_sizes(config.hidden_size, intermediate_dim) + bt = bt or default_sizes["bt"] + bf = bf or default_sizes["bf"] + bd1 = bd1 or default_sizes["bd1"] + bd2 = bd2 or default_sizes["bd2"] + btc = btc or default_sizes["btc"] + bfc = bfc or default_sizes["bfc"] + bd1c = bd1c or default_sizes["bd1c"] + bd2c = bd2c or default_sizes["bd2c"] + + self.bt = bt + self.bf = bf + self.bd1 = bd1 + self.bd2 = bd2 + self.btc = btc + self.bfc = bfc + self.bd1c = bd1c + self.bd2c = bd2c + + # Initialize weights in fused format + self.w1 = nnx.Param( + jax.random.normal( + jax.random.key(0), + (num_experts, 2, config.hidden_size, intermediate_dim), + dtype=weight_dtype, + out_sharding=P("tensor", None, None, None), + ) + ) + + self.w2 = nnx.Param( + jax.random.normal( + jax.random.key(0), + (num_experts, intermediate_dim, config.hidden_size), + dtype=weight_dtype, + out_sharding=P("tensor", None, None), + ) + ) + + def __call__(self, hidden_states: jax.Array, router_logits: jax.Array) -> jax.Array: + """ + Forward pass through the fused MoE layer. + + Args: + hidden_states: Input tokens, shape (num_tokens, hidden_size) or + (batch_size, seq_len, hidden_size) + router_logits: Router output logits, shape (num_tokens, num_experts) + Note: Should be raw logits, not after softmax or top-k + + Returns: + MoE layer output, same shape as hidden_states + """ + assert hidden_states.ndim == 2 + + hidden_states = jax.sharding.reshard(hidden_states, P("tensor", None)) + router_logits = jax.sharding.reshard(router_logits, P("tensor", None)) + + output = fused_ep_moe( + mesh=self.mesh, + tokens=hidden_states, + w1=self.w1.value, + w2=self.w2.value, + gating_output=router_logits, + top_k=self.num_experts_per_tok, + renormalize_topk_logits=self.renormalize_topk_logits, + act_fn=self.activation, + # Tile sizes + bt=self.bt, + bf=self.bf, + bd1=self.bd1, + bd2=self.bd2, + btc=self.btc, + bfc=self.bfc, + bd1c=self.bd1c, + bd2c=self.bd2c, + # Optional parameters (not used in basic case) + subc_quant_wsz=None, + w1_scale=None, + w2_scale=None, + b1=None, + b2=None, + ep_axis_name="tensor", + # tp_axis_name="data", + ) + + final_output = jax.sharding.reshard(output, P(None)) + return final_output diff --git a/python/sgl_jax/srt/managers/io_struct.py b/python/sgl_jax/srt/managers/io_struct.py index b800d000d..29ee52629 100644 --- a/python/sgl_jax/srt/managers/io_struct.py +++ b/python/sgl_jax/srt/managers/io_struct.py @@ -137,7 +137,7 @@ class GenerateReqInput: batch_size: int = 1 rid: list[str] | str | None = None text: list[str] | str | None = None - input_ids: list[int] = None + input_ids: list[list[int]] | list[int] | None = None # The embeddings for input_ids; one can specify either text or input_ids or input_embeds. 
input_embeds: list[list[list[float]]] | list[list[float]] | None = None sampling_params: Any | None = ( diff --git a/python/sgl_jax/srt/managers/tokenizer_manager.py b/python/sgl_jax/srt/managers/tokenizer_manager.py index a380daf71..ae29bdc7f 100644 --- a/python/sgl_jax/srt/managers/tokenizer_manager.py +++ b/python/sgl_jax/srt/managers/tokenizer_manager.py @@ -555,6 +555,7 @@ async def _handle_batch_request( pass async def flush_cache(self) -> FlushCacheReqOutput: + self.auto_create_handle_loop() return (await self.flush_cache_communicator(FlushCacheReqInput()))[0] def abort_request(self, rid: str = "", abort_all: bool = False): @@ -641,6 +642,7 @@ async def close_session( await self.send_to_scheduler.send_pyobj(obj) async def get_internal_state(self) -> list[dict[Any, Any]]: + self.auto_create_handle_loop() req = GetInternalStateReq() responses: list[GetInternalStateReqOutput] = await self.get_internal_state_communicator(req) # Many DP ranks diff --git a/python/sgl_jax/srt/managers/tp_worker.py b/python/sgl_jax/srt/managers/tp_worker.py index 37408736d..d0227e819 100644 --- a/python/sgl_jax/srt/managers/tp_worker.py +++ b/python/sgl_jax/srt/managers/tp_worker.py @@ -113,7 +113,8 @@ def __init__( else server_args.max_running_requests ) pool_limit = self.model_runner.req_to_token_pool.size - constraints = [server_limit, pool_limit, attn_backend_limit] + # constraints = [server_limit, pool_limit, attn_backend_limit] + constraints = [server_limit, pool_limit] self.max_running_requests = min(constraints) # Log each constraint for debugging logger.info("Max running requests constraints:") @@ -166,7 +167,7 @@ def __init__( ) self.precompile_bs_paddings = [] for bs in bs_padding_list: - if bs <= self.max_padded_batch_size: + if bs <= self.max_padded_batch_size and bs >= self.tp_size * 2: self.precompile_bs_paddings.append(bs) self.precompile_bs_paddings.sort() if ( diff --git a/python/sgl_jax/srt/mem_cache/radix_cache.py b/python/sgl_jax/srt/mem_cache/radix_cache.py index 7217c17f2..d187d5e11 100644 --- a/python/sgl_jax/srt/mem_cache/radix_cache.py +++ b/python/sgl_jax/srt/mem_cache/radix_cache.py @@ -82,6 +82,25 @@ def _convert_to_bigram_key(tokens: list[int]) -> list[tuple[int, int]]: return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)] +def get_child_key(key: list, page_size: int = 1): + if page_size == 1: + val = key[0] + if hasattr(val, "item"): + val = int(val) + if isinstance(val, list): + return tuple(val) + return val + + res = [] + for x in key[:page_size]: + if hasattr(x, "item"): + x = int(x) + if isinstance(x, list): + x = tuple(x) + res.append(x) + return tuple(res) + + class RadixCache(BasePrefixCache): def __init__( self, @@ -123,13 +142,11 @@ def __init__( if self.page_size == 1: self.key_match_fn = _key_match_page_size1 - self.get_child_key_fn = lambda key: (int(key[0]) if hasattr(key[0], "item") else key[0]) + self.get_child_key_fn = get_child_key else: self.key_match_fn = partial(_key_match_paged, page_size=page_size) # Ensure returning hashable types, convert numpy arrays to Python native types - self.get_child_key_fn = lambda key: tuple( - int(x) if hasattr(x, "item") else x for x in key[:page_size] - ) + self.get_child_key_fn = partial(get_child_key, page_size=page_size) self.reset() def _create_tokens_data(self, tokens: list[int]) -> np.ndarray: diff --git a/python/sgl_jax/srt/model_executor/model_runner.py b/python/sgl_jax/srt/model_executor/model_runner.py index 745671910..131a60cee 100644 --- a/python/sgl_jax/srt/model_executor/model_runner.py +++ 
b/python/sgl_jax/srt/model_executor/model_runner.py @@ -240,6 +240,7 @@ def load_model(self): self.model_config.configure_for_tensor_parallel(self.tp_size) self.model_config.log_kv_heads_info(self.tp_size) self.model_config.hf_config.ep_size = self.ep_size + self.model_config.hf_config.moe_backend = self.model_config.moe_backend.value self.model = self.model_loader.load_model( model_config=self.model_config, diff --git a/python/sgl_jax/srt/models/bailing_moe.py b/python/sgl_jax/srt/models/bailing_moe.py index 91d192ceb..c888e6856 100644 --- a/python/sgl_jax/srt/models/bailing_moe.py +++ b/python/sgl_jax/srt/models/bailing_moe.py @@ -259,24 +259,45 @@ def __init__( weight_dtype=router_dtype, score_func=getattr(config, "score_function", "sigmoid"), ) - self.topk = TopK( - topk=config.num_experts_per_tok, - renormalize=config.norm_topk_prob, - num_expert_group=config.n_group, - topk_group=config.topk_group, - routed_scaling_factor=config.routed_scaling_factor, - ) - self.mlp = EPMoE( - config=config, - num_experts=config.num_experts, - num_experts_per_tok=config.num_experts_per_tok, - intermediate_dim=config.moe_intermediate_size, - mesh=mesh, - ep_size=config.ep_size, - weight_dtype=dtype, - dtype=dtype, - layer_id=layer_id, - ) + + self.moe_backend = getattr(config, "moe_backend", "epmoe") + self.use_fused = self.moe_backend == "fused" + + if self.use_fused: + from sgl_jax.srt.layers.fused_moe import FusedEPMoE + + self.mlp = FusedEPMoE( + config=config, + num_experts=config.num_experts, + num_experts_per_tok=config.num_experts_per_tok, + intermediate_dim=config.moe_intermediate_size, + mesh=mesh, + ep_size=config.ep_size, + weight_dtype=dtype, + dtype=dtype, + layer_id=layer_id, + renormalize_topk_logits=config.norm_topk_prob, + ) + else: + self.topk = TopK( + topk=config.num_experts_per_tok, + renormalize=config.norm_topk_prob, + num_expert_group=config.n_group, + topk_group=config.topk_group, + routed_scaling_factor=config.routed_scaling_factor, + ) + self.mlp = EPMoE( + config=config, + num_experts=config.num_experts, + num_experts_per_tok=config.num_experts_per_tok, + intermediate_dim=config.moe_intermediate_size, + mesh=mesh, + ep_size=config.ep_size, + weight_dtype=dtype, + dtype=dtype, + layer_id=layer_id, + ) + if num_shared_experts > 0: self.shared_experts = BailingMoEMLP( hidden_size=config.hidden_size, @@ -338,9 +359,15 @@ def __call__( shared_output = None router_logits = self.moe_gate(hidden_states) - correction_bias = self.moe_gate.bias.value if self.moe_gate.bias is not None else None - topk_weights, topk_ids = self.topk(router_logits, correction_bias) - hidden_states = self.mlp(hidden_states, topk_weights, topk_ids) + if self.use_fused: + hidden_states = self.mlp(hidden_states, router_logits) + else: + correction_bias = ( + self.moe_gate.bias.value if self.moe_gate.bias is not None else None + ) + topk_weights, topk_ids = self.topk(router_logits, correction_bias) + hidden_states = self.mlp(hidden_states, topk_weights, topk_ids) + if shared_output is not None: hidden_states = hidden_states + shared_output else: @@ -569,27 +596,65 @@ def _create_moe_layer_mappings(self, layer_idx: int, is_mlp_layer: bool) -> dict mappings.update(shared_experts_mappings) num_experts = getattr(self.config, "num_experts", 256) - for expert_type in ["gate_proj", "up_proj", "down_proj"]: - target_name = { - "gate_proj": "wi_0", - "up_proj": "wi_1", - "down_proj": "wo", - }[expert_type] - expert_keys = [ - f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts) - ] - - if 
expert_type == "down_proj": - sharding = ("expert", "tensor", None) - else: - sharding = ("expert", None, "tensor") - - mappings[f"__MOE_EXPERTS__{prefix}.mlp.{target_name}"] = WeightMapping( - target_path=[f"{target_prefix}.mlp.{target_name}"] + expert_keys, - sharding=sharding, + moe_backend = getattr(self.config, "moe_backend", "epmoe") + use_fused = moe_backend == "fused" + + if use_fused: + # Fused MoE Mapping + # w1: fused gate_proj(w1) + up_proj(w3) -> (num_experts, 2, hidden, intermediate) + # w2: down_proj(w2) -> (num_experts, intermediate, hidden) + + # 1. Fused w1 (gate + up) + target_path_w1 = [f"{target_prefix}.mlp.w1"] + # Add source keys for gate_proj and up_proj + for name in ["gate_proj", "up_proj"]: + target_path_w1.extend( + [f"{prefix}.mlp.experts.{i}.{name}.weight" for i in range(num_experts)] + ) + + mappings[f"__MOE_EXPERTS__{prefix}.mlp.w1"] = WeightMapping( + target_path=target_path_w1, + sharding=("tensor", None, None, None), # (E, 2, H, I) transpose=True, + concat_axis=0, + fuse_moe_weights=True, + fuse_gate_up=("gate_proj", "up_proj"), ) + # 2. w2 (down) + target_path_w2 = [f"{target_prefix}.mlp.w2"] + target_path_w2.extend( + [f"{prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts)] + ) + + mappings[f"__MOE_EXPERTS__{prefix}.mlp.w2"] = WeightMapping( + target_path=target_path_w2, + sharding=("tensor", None, None), # (E, I, H) + transpose=True, + concat_axis=-1, + ) + else: + for expert_type in ["gate_proj", "up_proj", "down_proj"]: + target_name = { + "gate_proj": "wi_0", + "up_proj": "wi_1", + "down_proj": "wo", + }[expert_type] + expert_keys = [ + f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts) + ] + + if expert_type == "down_proj": + sharding = ("expert", "tensor", None) + else: + sharding = ("expert", None, "tensor") + + mappings[f"__MOE_EXPERTS__{prefix}.mlp.{target_name}"] = WeightMapping( + target_path=[f"{target_prefix}.mlp.{target_name}"] + expert_keys, + sharding=sharding, + transpose=True, + ) + return mappings def __call__( diff --git a/python/sgl_jax/srt/models/grok.py b/python/sgl_jax/srt/models/grok.py index 5a3a24c0d..8b73c6c01 100644 --- a/python/sgl_jax/srt/models/grok.py +++ b/python/sgl_jax/srt/models/grok.py @@ -17,6 +17,7 @@ _yarn_find_correction_range, _yarn_get_mscale, ) +from sgl_jax.srt.layers.fused_moe import FusedEPMoE from sgl_jax.srt.layers.layernorm import RMSNorm, dual_rmsnorm_forward from sgl_jax.srt.layers.linear import LinearBase from sgl_jax.srt.layers.logits_processor import ( @@ -206,6 +207,8 @@ class Grok1MoE(nnx.Module): kernel is used for the forward pass, with outputs reduced across ranks. 
""" + experts: FusedEPMoE | EPMoE + def __init__( self, config: PretrainedConfig, @@ -237,18 +240,37 @@ def __init__( self.router_logit_softcapping = getattr(config, "router_logit_softcapping", 30.0) - self.experts = EPMoE( - config=config, - num_experts=num_experts, - num_experts_per_tok=self.top_k, - intermediate_dim=intermediate_size, - mesh=mesh, - activation="gelu", - ep_size=config.ep_size, - weight_dtype=dtype, - dtype=dtype, - layer_id=layer_id, - ) + # Select MoE backend based on config + self.moe_backend = getattr(config, "moe_backend", "epmoe") + self.use_fused = self.moe_backend == "fused" + + if self.use_fused: + self.experts = FusedEPMoE( + config=config, + num_experts=num_experts, + num_experts_per_tok=self.top_k, + intermediate_dim=intermediate_size, + mesh=mesh, + activation="gelu", + ep_size=config.ep_size, + weight_dtype=dtype, + dtype=dtype, + layer_id=layer_id, + renormalize_topk_logits=False, # Match sglang behavior + ) + else: + self.experts = EPMoE( + config=config, + num_experts=num_experts, + num_experts_per_tok=self.top_k, + intermediate_dim=intermediate_size, + mesh=mesh, + activation="gelu", + ep_size=config.ep_size, + weight_dtype=dtype, + dtype=dtype, + layer_id=layer_id, + ) def __call__(self, hidden_states: jax.Array) -> jax.Array: # Router computation with soft capping @@ -259,15 +281,22 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array: router_logits = router_logits / self.router_logit_softcapping router_logits = jax.nn.tanh(router_logits) * self.router_logit_softcapping - # Compute top-k routing weights using sglang-style approach: - # 1. Compute global softmax over ALL experts (not just top-k) - # 2. Select top-k experts based on logits - # 3. Extract corresponding weights (no renormalization) - top_k_weights, top_k_indices = self._custom_topk( - router_logits, self.top_k, renormalize=False - ) + if self.use_fused: + # Fused kernel: pass router_logits directly + # Top-K selection is handled internally by the kernel + assert isinstance(self.experts, FusedEPMoE) + return self.experts(hidden_states, router_logits) + else: + # EPMoE: compute top-k routing weights using sglang-style approach: + # 1. Compute global softmax over ALL experts (not just top-k) + # 2. Select top-k experts based on logits + # 3. Extract corresponding weights (no renormalization) + assert isinstance(self.experts, EPMoE) + top_k_weights, top_k_indices = self._custom_topk( + router_logits, self.top_k, renormalize=False + ) - return self.experts(hidden_states, top_k_weights, top_k_indices) + return self.experts(hidden_states, top_k_weights, top_k_indices) def _custom_topk( self, router_logits: jax.Array, top_k: int, renormalize: bool = False @@ -904,36 +933,89 @@ def _create_layer_mappings(self, layer_idx: int) -> dict[str, WeightMapping]: ), } - # CRITICAL: Correct MoE weight mapping - # w1 (gate_proj) -> wi_0, w3 (up_proj) -> wi_1, w2 (down_proj) -> wo - for name, target_name in [("w1", "wi_0"), ("w3", "wi_1"), ("w2", "wo")]: - target_path = [f"{target_prefix}.block_sparse_moe.experts.{target_name}"] - target_path.extend( + moe_backend = getattr(self.config, "moe_backend", "epmoe") + use_fused = moe_backend == "fused" + + if use_fused: + # Fused MoE Mapping + # w1: fused gate(w1) + up(w3) -> (num_experts, 2, hidden, intermediate) + # w2: down(w2) -> (num_experts, intermediate, hidden) + + # 1. 
Fused w1 (gate + up) + target_path_w1 = [f"{target_prefix}.block_sparse_moe.w1"] + # Add source keys for w1 (gate) and w3 (up) + # Note: Grok experts are 0..N-1 + for name in ["w1", "w3"]: + target_path_w1.extend( + [ + f"{prefix}.block_sparse_moe.experts.{i}.{name}.weight" + for i in range(self.config.num_local_experts) + ] + ) + + mappings[f"__MOE_EXPERTS__{prefix}.block_sparse_moe.w1"] = WeightMapping( + target_path=target_path_w1, + sharding=("tensor", None, None, None), # (E, 2, H, I) + transpose=True, + concat_axis=0, # concat along E axis + fuse_moe_weights=True, + fuse_gate_up=("w1", "w3"), + ) + + # 2. w2 (down) + target_path_w2 = [f"{target_prefix}.block_sparse_moe.w2"] + target_path_w2.extend( [ - f"{prefix}.block_sparse_moe.experts.{i}.{name}.weight" + f"{prefix}.block_sparse_moe.experts.{i}.w2.weight" for i in range(self.config.num_local_experts) ] ) - sharding = ( - ("expert", "tensor", None) if target_name == "wo" else ("expert", None, "tensor") + mappings[f"__MOE_EXPERTS__{prefix}.block_sparse_moe.w2"] = WeightMapping( + target_path=target_path_w2, + sharding=("tensor", None, None), # (E, I, H) + transpose=True, + concat_axis=-1, ) - if name == "w2": - # w2 (down_proj) -> wo: HF shape (8192, 2048), concat -> (8192, 16384), transpose -> (16384, 8192) - mappings[f"__MOE_EXPERTS__{prefix}.block_sparse_moe.experts.{target_name}"] = ( - WeightMapping( - target_path=target_path, sharding=sharding, transpose=True, concat_axis=-1 - ) + else: + # Standard EPMoE Mapping + for name, target_name in [("w1", "wi_0"), ("w3", "wi_1"), ("w2", "wo")]: + target_path = [f"{target_prefix}.block_sparse_moe.experts.{target_name}"] + target_path.extend( + [ + f"{prefix}.block_sparse_moe.experts.{i}.{name}.weight" + for i in range(self.config.num_local_experts) + ] ) - else: - # w1/w3 (gate/up) -> wi_0/wi_1: HF shape (2048, 8192), concat -> (16384, 8192), transpose -> (8192, 16384) - mappings[f"__MOE_EXPERTS__{prefix}.block_sparse_moe.experts.{target_name}"] = ( - WeightMapping( - target_path=target_path, sharding=sharding, transpose=True, concat_axis=0 - ) + + sharding = ( + ("expert", "tensor", None) + if target_name == "wo" + else ("expert", None, "tensor") ) + if name == "w2": + # w2 (down_proj) -> wo + mappings[f"__MOE_EXPERTS__{prefix}.block_sparse_moe.experts.{target_name}"] = ( + WeightMapping( + target_path=target_path, + sharding=sharding, + transpose=True, + concat_axis=-1, + ) + ) + else: + # w1/w3 (gate/up) -> wi_0/wi_1 + mappings[f"__MOE_EXPERTS__{prefix}.block_sparse_moe.experts.{target_name}"] = ( + WeightMapping( + target_path=target_path, + sharding=sharding, + transpose=True, + concat_axis=0, + ) + ) + return mappings diff --git a/python/sgl_jax/srt/models/qwen2_moe.py b/python/sgl_jax/srt/models/qwen2_moe.py index 0db410ef2..ba23d5b0c 100644 --- a/python/sgl_jax/srt/models/qwen2_moe.py +++ b/python/sgl_jax/srt/models/qwen2_moe.py @@ -216,10 +216,42 @@ def __init__( num_experts=num_experts, weight_dtype=dtype, ) - self.topk = TopK( - topk=num_experts_per_tok, - renormalize=getattr(config, "norm_topk_prob", True), - ) + + self.moe_backend = getattr(config, "moe_backend", "epmoe") + self.use_fused = self.moe_backend == "fused" + + if self.use_fused: + from sgl_jax.srt.layers.fused_moe import FusedEPMoE + + self.mlp = FusedEPMoE( + config=config, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + intermediate_dim=moe_intermediate_size, + mesh=mesh, + ep_size=config.ep_size, + weight_dtype=dtype, + dtype=dtype, + layer_id=layer_id, + 
renormalize_topk_logits=getattr(config, "norm_topk_prob", True), + ) + else: + self.topk = TopK( + topk=num_experts_per_tok, + renormalize=getattr(config, "norm_topk_prob", True), + ) + self.mlp = EPMoE( + config=config, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + intermediate_dim=moe_intermediate_size, + mesh=mesh, + ep_size=config.ep_size, + weight_dtype=dtype, + dtype=dtype, + layer_id=layer_id, + ) + # Optional shared expert path shared_sz = getattr(config, "shared_expert_intermediate_size", 0) if shared_sz and shared_sz > 0: @@ -243,18 +275,6 @@ def __init__( self.shared_experts = None self.shared_expert_gate = None - self.mlp = EPMoE( - config=config, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - intermediate_dim=moe_intermediate_size, - mesh=mesh, - ep_size=config.ep_size, - weight_dtype=dtype, - dtype=dtype, - layer_id=layer_id, - ) - self.input_layernorm = RMSNorm( config.hidden_size, epsilon=config.rms_norm_eps, @@ -304,8 +324,12 @@ def __call__( shared_output = None router_logits = self.moe_gate(hidden_states) - topk_weights, topk_ids = self.topk(router_logits) - mlp_output = self.mlp(hidden_states, topk_weights, topk_ids) + if self.use_fused: + mlp_output = self.mlp(hidden_states, router_logits) + else: + topk_weights, topk_ids = self.topk(router_logits) + mlp_output = self.mlp(hidden_states, topk_weights, topk_ids) + hidden_states = mlp_output if shared_output is None else (mlp_output + shared_output) return hidden_states, residual, kv_fused @@ -553,26 +577,64 @@ def _create_moe_layer_mappings(self, layer_idx: int) -> dict: mappings.update(shared_expert_mappings) num_experts = getattr(self.config, "num_experts", 8) - for expert_type in ["gate_proj", "up_proj", "down_proj"]: - target_name = { - "gate_proj": "wi_0", - "up_proj": "wi_1", - "down_proj": "wo", - }[expert_type] - expert_keys = [ - f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts) - ] + moe_backend = getattr(self.config, "moe_backend", "epmoe") + use_fused = moe_backend == "fused" + + if use_fused: + # Fused MoE Mapping + # w1: fused gate_proj(w1) + up_proj(w3) -> (num_experts, 2, hidden, intermediate) + # w2: down_proj(w2) -> (num_experts, intermediate, hidden) + + # 1. Fused w1 (gate + up) + target_path_w1 = [f"{target_prefix}.mlp.w1"] + # Add source keys for gate_proj and up_proj + for name in ["gate_proj", "up_proj"]: + target_path_w1.extend( + [f"{prefix}.mlp.experts.{i}.{name}.weight" for i in range(num_experts)] + ) + + mappings[f"__MOE_EXPERTS__{prefix}.mlp.w1"] = WeightMapping( + target_path=target_path_w1, + sharding=("tensor", None, None, None), # (E, 2, H, I) + transpose=True, + concat_axis=0, + fuse_moe_weights=True, + fuse_gate_up=("gate_proj", "up_proj"), + ) - if expert_type == "down_proj": - sharding = ("expert", "tensor", None) - else: - sharding = ("expert", None, "tensor") + # 2. 
w2 (down) + target_path_w2 = [f"{target_prefix}.mlp.w2"] + target_path_w2.extend( + [f"{prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts)] + ) - mappings[f"__MOE_EXPERTS__{prefix}.mlp.{target_name}"] = WeightMapping( - target_path=[f"{target_prefix}.mlp.{target_name}"] + expert_keys, - sharding=sharding, + mappings[f"__MOE_EXPERTS__{prefix}.mlp.w2"] = WeightMapping( + target_path=target_path_w2, + sharding=("tensor", None, None), # (E, I, H) transpose=True, + concat_axis=-1, ) + else: + for expert_type in ["gate_proj", "up_proj", "down_proj"]: + target_name = { + "gate_proj": "wi_0", + "up_proj": "wi_1", + "down_proj": "wo", + }[expert_type] + expert_keys = [ + f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts) + ] + + if expert_type == "down_proj": + sharding = ("expert", "tensor", None) + else: + sharding = ("expert", None, "tensor") + + mappings[f"__MOE_EXPERTS__{prefix}.mlp.{target_name}"] = WeightMapping( + target_path=[f"{target_prefix}.mlp.{target_name}"] + expert_keys, + sharding=sharding, + transpose=True, + ) return mappings diff --git a/python/sgl_jax/srt/models/qwen3_moe.py b/python/sgl_jax/srt/models/qwen3_moe.py index 6c78effc2..2220bce97 100644 --- a/python/sgl_jax/srt/models/qwen3_moe.py +++ b/python/sgl_jax/srt/models/qwen3_moe.py @@ -8,6 +8,7 @@ from sgl_jax.srt.configs.model_config import ModelConfig from sgl_jax.srt.layers.embeddings import Embed, ParallelLMHead, RotaryEmbedding +from sgl_jax.srt.layers.fused_moe import FusedEPMoE from sgl_jax.srt.layers.layernorm import RMSNorm from sgl_jax.srt.layers.linear import LinearBase from sgl_jax.srt.layers.logits_processor import LogitsMetadata, LogitsProcessor @@ -181,25 +182,45 @@ def __init__( num_experts = getattr(config, "num_experts", 128) num_experts_per_tok = getattr(config, "num_experts_per_tok", 8) moe_intermediate_size = getattr(config, "moe_intermediate_size", 768) - self.topk = TopK( - topk=num_experts_per_tok, - renormalize=config.norm_topk_prob, - ) + + self.moe_backend = getattr(config, "moe_backend", "epmoe") + self.use_fused = self.moe_backend == "fused" + self.moe_gate = GateLogit( input_size=config.hidden_size, num_experts=num_experts, ) - self.mlp = EPMoE( - config=config, - num_experts=num_experts, - num_experts_per_tok=num_experts_per_tok, - intermediate_dim=moe_intermediate_size, - mesh=mesh, - ep_size=config.ep_size, - weight_dtype=dtype, - dtype=dtype, - layer_id=layer_id, - ) + + if self.use_fused: + self.mlp = FusedEPMoE( + config=config, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + intermediate_dim=moe_intermediate_size, + mesh=mesh, + activation="silu", + ep_size=config.ep_size, + weight_dtype=dtype, + dtype=dtype, + layer_id=layer_id, + renormalize_topk_logits=config.norm_topk_prob, + ) + else: + self.topk = TopK( + topk=num_experts_per_tok, + renormalize=config.norm_topk_prob, + ) + self.mlp = EPMoE( + config=config, + num_experts=num_experts, + num_experts_per_tok=num_experts_per_tok, + intermediate_dim=moe_intermediate_size, + mesh=mesh, + ep_size=config.ep_size, + weight_dtype=dtype, + dtype=dtype, + layer_id=layer_id, + ) self.is_moe_layer = True self.input_layernorm = RMSNorm( @@ -242,8 +263,12 @@ def __call__( if self.is_moe_layer: router_logits = self.moe_gate(hidden_states) - topk_weights, topk_ids = self.topk(router_logits) - hidden_states = self.mlp(hidden_states, topk_weights, topk_ids) + + if self.use_fused: + hidden_states = self.mlp(hidden_states, router_logits) + else: + topk_weights, topk_ids = 
self.topk(router_logits) + hidden_states = self.mlp(hidden_states, topk_weights, topk_ids) else: hidden_states = self.mlp(hidden_states) @@ -478,47 +503,57 @@ def _create_moe_layer_mappings(self, layer_idx: int, is_mlp_layer: bool) -> dict transpose=True, ) + moe_backend = getattr(self.config, "moe_backend", "epmoe") + use_fused = moe_backend == "fused" num_experts = getattr(self.config, "num_experts", 128) - for expert_type in ["gate_proj", "up_proj", "down_proj"]: - target_name = { - "gate_proj": "wi_0", - "up_proj": "wi_1", - "down_proj": "wo", - }[expert_type] - - expert_keys = [ - f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts) - ] - if expert_type == "down_proj": - sharding = ("expert", "tensor", None) - else: - sharding = ("expert", None, "tensor") - # world_size = ( - # self.mesh.shape.get("data", 1) - # * self.mesh.shape.get("tensor", 1) - # * self.mesh.shape.get("expert", 1) - # ) - # tp_size = world_size // self.config.ep_size - - # if self.config.ep_size == 1: - # # TP - # if expert_type == "down_proj": - # sharding = (None, ("data", "tensor"), None) - # else: - # sharding = (None, None, ("data", "tensor")) - # elif tp_size > 1: - # # ETP - - # else: - # # EP - # sharding = (("data", "tensor"), None, None) - - mappings[f"__MOE_EXPERTS__{prefix}.mlp.{target_name}"] = WeightMapping( - target_path=[f"{target_prefix}.mlp.{target_name}"] + expert_keys, - sharding=sharding, + if use_fused: + # Fused MoE Mapping + # w1: fused gate_proj(w1) + up_proj(w3) -> (num_experts, 2, hidden, intermediate) + # w2: down_proj(w2) -> (num_experts, intermediate, hidden) + w1_expert_keys = [] + for expert_type in ["gate_proj", "up_proj"]: + w1_expert_keys = w1_expert_keys + [ + f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts) + ] + mappings[f"__MOE_EXPERTS__{prefix}.mlp.w1"] = WeightMapping( + target_path=[f"{target_prefix}.mlp.w1"] + w1_expert_keys, + sharding=("tensor", None, None, None), # (E, 2, H, I) + transpose=True, + fuse_moe_weights=True, + fuse_gate_up=("gate_proj", "up_proj"), + ) + w2_expert_keys = [ + f"{prefix}.mlp.experts.{i}.down_proj.weight" for i in range(num_experts) + ] + mappings[f"__MOE_EXPERTS__{prefix}.mlp.w2"] = WeightMapping( + target_path=[f"{target_prefix}.mlp.w2"] + w2_expert_keys, + sharding=("tensor", None, None), # (E, I, H) transpose=True, ) + else: + # EPMoE mapping - always use expert sharding + for expert_type in ["gate_proj", "up_proj", "down_proj"]: + target_name = { + "gate_proj": "wi_0", + "up_proj": "wi_1", + "down_proj": "wo", + }[expert_type] + + expert_keys = [ + f"{prefix}.mlp.experts.{i}.{expert_type}.weight" for i in range(num_experts) + ] + + if expert_type == "down_proj": + sharding = ("expert", "tensor", None) + else: + sharding = ("expert", None, "tensor") + + mappings[f"__MOE_EXPERTS__{prefix}.mlp.{target_name}"] = WeightMapping( + target_path=[f"{target_prefix}.mlp.{target_name}"] + expert_keys, + sharding=sharding, + transpose=True, + ) return mappings diff --git a/python/sgl_jax/srt/server_args.py b/python/sgl_jax/srt/server_args.py index 476eb157d..cb1d990d8 100644 --- a/python/sgl_jax/srt/server_args.py +++ b/python/sgl_jax/srt/server_args.py @@ -122,6 +122,7 @@ class ServerArgs: # Kernel backend attention_backend: str | None = "fa" + moe_backend: str = "epmoe" grammar_backend: str | None = None @@ -796,6 +797,13 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.attention_backend, help="Choose the kernels for attention layers.", ) + parser.add_argument( + 
"--moe-backend", + type=str, + choices=["epmoe", "fused", "auto"], + default=ServerArgs.moe_backend, + help="The backend to use for MoE models.", + ) parser.add_argument( "--enable-nan-detection", diff --git a/python/sgl_jax/srt/utils/weight_utils.py b/python/sgl_jax/srt/utils/weight_utils.py index bd509635c..6d7347a57 100644 --- a/python/sgl_jax/srt/utils/weight_utils.py +++ b/python/sgl_jax/srt/utils/weight_utils.py @@ -28,6 +28,9 @@ class WeightMapping: kv_head_padding: bool = False concat_axis: int | None = None is_eagle3: bool = False + # MoE weight fusion configuration + fuse_moe_weights: bool = False + fuse_gate_up: tuple[str, str] | None = None def __post_init__(self): if self.sharding is None: @@ -159,36 +162,92 @@ def load_weights_from_safetensors( for moe_key, mapping in moe_mappings.items(): expected_hf_keys = mapping.target_path[1:] # list of expected HF keys if hf_key in expected_hf_keys: - if moe_key not in moe_buffer: - moe_buffer[moe_key] = {} - if hf_key not in moe_buffer[moe_key]: - moe_buffer[moe_key][hf_key] = [] - moe_buffer[moe_key][hf_key].append(hf_weight) - assigned = True - - if len(moe_buffer[moe_key]) == len(expected_hf_keys): - shard_counts = [len(v) for v in moe_buffer[moe_key].values()] - # Validate all weights have consistent shard counts - if len(set(shard_counts)) != 1: + if mapping.fuse_moe_weights: + # Fused MoE Logic + gate_id, up_id = mapping.fuse_gate_up + if gate_id in hf_key: + group_type = "gate" + elif up_id in hf_key: + group_type = "up" + else: + logger.warning( + "Fused key %s matches neither %s nor %s", + hf_key, + gate_id, + up_id, + ) continue - # Auto-detect TP sharding: - # - Grok-2: concat_axis is set, needs multiple shards (e.g., 8) - if mapping.concat_axis is not None: - # TP-sharded weights: need to collect all TP shards - # Expected number of shards = total model files / experts per file - if shard_counts[0] < safetensors_partition: - # Still collecting shards, wait for more + if moe_key not in moe_buffer: + moe_buffer[moe_key] = {"gate": {}, "up": {}} + + if hf_key not in moe_buffer[moe_key][group_type]: + moe_buffer[moe_key][group_type][hf_key] = [] + + moe_buffer[moe_key][group_type][hf_key].append(hf_weight) + assigned = True + + # Check if we have all necessary weights + total_captured = len(moe_buffer[moe_key]["gate"]) + len( + moe_buffer[moe_key]["up"] + ) + + if total_captured == len(expected_hf_keys): + # Validate shard counts for ALL weights + all_shard_counts = [] + for g_type in ["gate", "up"]: + for w_list in moe_buffer[moe_key][g_type].values(): + all_shard_counts.append(len(w_list)) + + if not all_shard_counts: # Should not happen if total_captured > 0 continue - else: - # Non-TP-sharded weights: expect exactly 1 copy per expert - if shard_counts[0] != 1: + + if len(set(all_shard_counts)) != 1: continue - self._process_single_moe_group( - params, moe_key, mapping, moe_buffer[moe_key] - ) - del moe_buffer[moe_key] # free memory + if mapping.concat_axis is not None: + if all_shard_counts[0] < safetensors_partition: + continue + elif all_shard_counts[0] != 1: + continue + + self._process_fused_moe_group( + params, moe_key, mapping, moe_buffer[moe_key] + ) + del moe_buffer[moe_key] + + else: + # Regular (Non-Fused) MoE Logic + if moe_key not in moe_buffer: + moe_buffer[moe_key] = {} + if hf_key not in moe_buffer[moe_key]: + moe_buffer[moe_key][hf_key] = [] + moe_buffer[moe_key][hf_key].append(hf_weight) + assigned = True + + if len(moe_buffer[moe_key]) == len(expected_hf_keys): + shard_counts = [len(v) for v in 
moe_buffer[moe_key].values()] + # Validate all weights have consistent shard counts + if len(set(shard_counts)) != 1: + continue + + # Auto-detect TP sharding: + # - Grok-2: concat_axis is set, needs multiple shards (e.g., 8) + if mapping.concat_axis is not None: + # TP-sharded weights: need to collect all TP shards + # Expected number of shards = total model files / experts per file + if shard_counts[0] < safetensors_partition: + # Still collecting shards, wait for more + continue + else: + # Non-TP-sharded weights: expect exactly 1 copy per expert + if shard_counts[0] != 1: + continue + + self._process_single_moe_group( + params, moe_key, mapping, moe_buffer[moe_key] + ) + del moe_buffer[moe_key] # free memory break if not assigned: @@ -203,10 +262,24 @@ def load_weights_from_safetensors( for moe_key in moe_buffer: mapping = moe_mappings[moe_key] expected = len(mapping.target_path[1:]) - got = len(moe_buffer[moe_key]) - shard_counts = ( - [len(v) for v in moe_buffer[moe_key].values()] if moe_buffer[moe_key] else [] - ) + + if mapping.fuse_moe_weights: + got_gate = len(moe_buffer[moe_key].get("gate", {})) + got_up = len(moe_buffer[moe_key].get("up", {})) + got = got_gate + got_up + shard_counts = [] + if "gate" in moe_buffer[moe_key]: + shard_counts.extend([len(v) for v in moe_buffer[moe_key]["gate"].values()]) + if "up" in moe_buffer[moe_key]: + shard_counts.extend([len(v) for v in moe_buffer[moe_key]["up"].values()]) + else: + got = len(moe_buffer[moe_key]) + shard_counts = ( + [len(v) for v in moe_buffer[moe_key].values()] + if moe_buffer[moe_key] + else [] + ) + logger.error( "MoE group %s incomplete: %s/%s weights loaded, shard_counts=%s, concat_axis=%s", moe_key, @@ -219,6 +292,9 @@ def load_weights_from_safetensors( nnx.update(self.model, params) + # Final verification: check all fused MoE layers + self._verify_fused_moe_weights(params, moe_mappings) + def _process_single_moe_group( self, params: nnx.State, @@ -264,6 +340,116 @@ def _process_single_moe_group( logger.debug("Assigned MoE group %s, shape: %s", moe_key, stacked_weight.shape) + def _process_fused_moe_group( + self, + params: nnx.State, + moe_key: str, + mapping: WeightMapping, + grouped_weights: dict[str, dict[str, list[jax.Array]]], + ): + """ + Process fused MoE weight groups (gate + up weights). 
+ + Args: + params: Model parameter state + moe_key: MoE weight key (e.g., "__MOE_EXPERTS__model.layers.0.block_sparse_moe.experts.w1") + mapping: Weight mapping configuration + grouped_weights: Grouped weights dict + { + "gate": {hf_key: [weight_shard1, weight_shard2, ...]}, + "up": {hf_key: [weight_shard1, weight_shard2, ...]} + } + """ + target_path = mapping.target_path[0] + expected_hf_keys = mapping.target_path[1:] + + # Step 1: Process gate and up weights separately + # Use the predefined order from expected_hf_keys, not sorting + gate_weights = [] + up_weights = [] + + gate_id, up_id = mapping.fuse_gate_up + + # Separate expected keys into gate and up based on fuse_gate_up config + for hf_key in expected_hf_keys: + if gate_id in hf_key: + # This is a gate weight + weights = grouped_weights["gate"][hf_key] + + # Concatenate TP shards + if mapping.concat_axis is not None and len(weights) > 1: + weight = jnp.concatenate(weights, axis=mapping.concat_axis) + else: + weight = weights[0] + + # Transpose + if mapping.transpose: + weight = jnp.transpose(weight, (1, 0)) + + gate_weights.append(weight) + + elif up_id in hf_key: + # This is an up weight + weights = grouped_weights["up"][hf_key] + + # Concatenate TP shards + if mapping.concat_axis is not None and len(weights) > 1: + weight = jnp.concatenate(weights, axis=mapping.concat_axis) + else: + weight = weights[0] + + # Transpose + if mapping.transpose: + weight = jnp.transpose(weight, (1, 0)) + + up_weights.append(weight) + + # Step 2: Stack to 3D tensors + # gate_stacked: (num_experts, hidden_size, intermediate_size) + # up_stacked: (num_experts, hidden_size, intermediate_size) + gate_stacked = jnp.stack(gate_weights, axis=0) + up_stacked = jnp.stack(up_weights, axis=0) + + # Step 3: Fuse to 4D tensor + # fused_weight: (num_experts, 2, hidden_size, intermediate_size) + fused_weight = jnp.stack([gate_stacked, up_stacked], axis=1) + + # Step 4: Apply sharding + if "expert" in mapping.sharding: + ep_size = getattr(self.model_config.hf_config, "ep_size", 1) + world_size = self.mesh.shape.get("data", 1) * self.mesh.shape.get("tensor", 1) + tp_size = world_size // ep_size + + devices = self.mesh.devices.flatten() + moe_mesh = jax.sharding.Mesh( + devices.reshape(ep_size, tp_size), axis_names=("expert", "tensor") + ) + + sharded_weight = self._shard_weight(fused_weight, mapping.sharding, mesh=moe_mesh) + else: + sharded_weight = self._shard_weight(fused_weight, mapping.sharding) + + # Step 5: Assign to model parameter + model_param = self._get_param(params, target_path) + original_dtype = model_param.value.dtype + expected_shape = model_param.value.shape + + # Validate shape before assignment + if fused_weight.shape != expected_shape: + raise ValueError( + f"Fused MoE weight shape mismatch for {target_path}: " + f"expected {expected_shape}, got {fused_weight.shape}" + ) + + model_param.value = sharded_weight.astype(original_dtype) + + # Verify assignment was successful + actual_shape = model_param.value.shape + if actual_shape != expected_shape: + raise RuntimeError( + f"Failed to assign fused MoE weight to {target_path}: shape mismatch" + ) + def _load_dummy_weights( self, params: nnx.State, @@ -838,3 +1024,72 @@ def _is_excluded_layer_weight(self, hf_key: str) -> bool: layer_num = int(parts[2]) return layer_num >= self.model_config.num_hidden_layers + + def _verify_fused_moe_weights( + self, params: nnx.State, moe_mappings: dict[str, WeightMapping] + ) -> None: + """Verify that all fused MoE weights were loaded correctly.""" + # Get all 
fused w1 mappings + fused_w1_mappings = { + k: v for k, v in moe_mappings.items() if getattr(v, "fuse_moe_weights", False) + } + + # Get corresponding w2 mappings (same layer, but w2 instead of w1) + w2_mappings = {} + for k in fused_w1_mappings: + w2_key = k.replace(".w1", ".w2") + if w2_key in moe_mappings: + w2_mappings[w2_key] = moe_mappings[w2_key] + + if not fused_w1_mappings: + return + + all_verified = True + verified_count = 0 + + # Verify w1 and w2 weights + for _, mapping in fused_w1_mappings.items(): + target_path = mapping.target_path[0] + try: + model_param = self._get_param(params, target_path) + weight_shape = model_param.value.shape + weight_values = model_param.value + + if ( + len(weight_shape) != 4 + or weight_shape[1] != 2 + or jnp.all(weight_values == 0) + or jnp.any(jnp.isnan(weight_values)) + ): + logger.error("✗ %s: Invalid or corrupted weights", target_path) + all_verified = False + else: + verified_count += 1 + except (KeyError, AttributeError, ValueError) as e: + logger.error("✗ %s: Failed to access - %s", target_path, str(e)) + all_verified = False + + for _, mapping in w2_mappings.items(): + target_path = mapping.target_path[0] + try: + model_param = self._get_param(params, target_path) + weight_shape = model_param.value.shape + weight_values = model_param.value + + if ( + len(weight_shape) != 3 + or jnp.all(weight_values == 0) + or jnp.any(jnp.isnan(weight_values)) + ): + logger.error("✗ %s (w2): Invalid or corrupted weights", target_path) + all_verified = False + else: + verified_count += 1 + except (KeyError, AttributeError, ValueError) as e: + logger.error("✗ %s (w2): Failed to access - %s", target_path, str(e)) + all_verified = False + + if all_verified: + logger.info("✓ Fused MoE weights verified: %d layers", verified_count // 2) + else: + raise RuntimeError("Fused MoE weight verification failed") diff --git a/python/sgl_jax/test/kernels/fused_moe_v1_test.py b/python/sgl_jax/test/kernels/fused_moe_v1_test.py new file mode 100644 index 000000000..17fdbd302 --- /dev/null +++ b/python/sgl_jax/test/kernels/fused_moe_v1_test.py @@ -0,0 +1,389 @@ +# Adapted from https://github.com/vllm-project/tpu-inference/blob/main/tests/kernels/fused_moe_v1_test.py +# Copyright 2025 The tpu-inference Authors. All rights reserved. 
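Before the tests, a note on the layout that `_process_fused_moe_group` produces and `_verify_fused_moe_weights` checks above: each per-expert gate/up matrix from the checkpoint is transposed, stacked across experts, and then gate and up are stacked along a new axis, yielding the `(num_experts, 2, hidden, intermediate)` tensor that `FusedEPMoE.w1` expects, while w2 stays `(num_experts, intermediate, hidden)`. A toy sketch of that transformation, with made-up dimensions:

```python
# Sketch of the shape manipulation performed by _process_fused_moe_group
# (toy dimensions, dummy arrays; not part of this diff).
import jax.numpy as jnp

num_experts, hidden, inter = 4, 8, 16

# Per-expert HF weights arrive as (inter, hidden) and are transposed to (hidden, inter).
gate_weights = [jnp.zeros((inter, hidden)).T for _ in range(num_experts)]
up_weights = [jnp.ones((inter, hidden)).T for _ in range(num_experts)]

gate_stacked = jnp.stack(gate_weights, axis=0)             # (E, H, I)
up_stacked = jnp.stack(up_weights, axis=0)                 # (E, H, I)
fused_w1 = jnp.stack([gate_stacked, up_stacked], axis=1)   # (E, 2, H, I)

assert fused_w1.shape == (num_experts, 2, hidden, inter)
# fused_w1[:, 0] holds gate_proj and fused_w1[:, 1] holds up_proj,
# matching the w1 parameter of FusedEPMoE.
```

The verification pass above enforces exactly this layout: w1 must be 4D with a size-2 second axis and free of NaN or all-zero values, and w2 must remain 3D.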
+import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import absltest, parameterized +from jax._src import test_util as jtu +from jax.sharding import Mesh + +from sgl_jax.srt.kernels.fused_moe.v1.kernel import fused_ep_moe, ref_moe + +jax.config.parse_flags_with_absl() + + +def cdiv(a, b): + assert b != 0 + return (a + b - 1) // b + + +def align_to(x, a): + return cdiv(x, a) * a + + +def gen_moe_inputs( + dtype, + top_k, + num_experts, + hidden_size, + intermediate_size, + num_tokens, + *, + seed=1234, + has_bias=False, +): + key = jax.random.key(seed) + k0, k1, k2, k3, k4, k5, k6 = jax.random.split(key, 7) + + a = jax.random.normal(k0, (num_tokens, hidden_size), dtype=jnp.float32).astype(dtype) / 10 + + w1 = ( + jax.random.normal( + k1, + (num_experts, 2, hidden_size, intermediate_size), + dtype=jnp.float32, + ) + / 10 + ).astype(dtype) + w2 = ( + jax.random.normal(k2, (num_experts, intermediate_size, hidden_size), dtype=jnp.float32) / 10 + ).astype(dtype) + + if has_bias: + b1 = ( + jax.random.normal(k3, (num_experts, 2, 1, intermediate_size), dtype=jnp.float32) / 10 + ).astype(dtype) + b2 = (jax.random.normal(k4, (num_experts, 1, hidden_size), dtype=jnp.float32) / 10).astype( + dtype + ) + else: + b1 = b2 = None + + gating_output = ( + jax.random.normal(k5, (num_tokens, num_experts), dtype=jnp.float32) + + jnp.arange(num_tokens * num_experts, dtype=jnp.float32).reshape(num_tokens, num_experts) + / 100 + ) + + # To generate unique top-k! + top_k_indices = jax.random.randint( + k6, (num_tokens, top_k), minval=0, maxval=num_experts - 1, dtype=jnp.int32 + ) + + one_hot = ( + jnp.sum( + jax.nn.one_hot(top_k_indices, num_experts, dtype=jnp.float32), + axis=1, + ) + * 30 + ) + + gating_output = (gating_output + one_hot).astype(dtype) + + return a, w1, w2, b1, b2, gating_output + + +def sub_channel_quantize(x, quant_dtype, wsz=256): + """Quantizes x with sub-channel quantization on the 2nd minor.""" + if jnp.issubdtype(quant_dtype, jnp.floating): + dtype_info = jnp.finfo(quant_dtype) + else: + dtype_info = jnp.iinfo(quant_dtype) + dtype_max = float(dtype_info.max) + w_lst, scale_lst = [], [] + assert len(x.shape) >= 2 + assert x.shape[-2] % wsz == 0 + for i in range(0, x.shape[-2], wsz): + y = x[..., i : i + wsz, :] + abs_max = jnp.abs(y).max(axis=-2, keepdims=True) + scale = (abs_max / dtype_max).astype(jnp.float32) + w = (y / scale).astype(quant_dtype) + w_lst.append(w) + scale_lst.append(scale) + return jnp.concat(w_lst, axis=-2), jnp.expand_dims(jnp.concat(scale_lst, axis=-2), axis=-2) + + +@jtu.with_config(jax_numpy_dtype_promotion="standard") +class MoEKernelTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + self.mesh_devices = sorted( + jax.devices(), + key=lambda x: ( + x.coords[0], + (-1 if x.coords[0] % 2 else 1) * x.coords[1], + ), + ) + self.mesh = Mesh(np.array(self.mesh_devices).reshape(-1, 1), axis_names=("tensor", "data")) + + def _test_moe( + self, + dtype, + top_k, + num_experts, + hidden_size, + intermediate_size, + num_tokens, + seed, + renormalize_topk_logits, + bt, + bf, + bd1, + bd2, + btc, + bfc, + bd1c, + bd2c, + act_fn="silu", + w_dtype=None, + subc_quant_wsz=None, + has_bias=False, + atol=2e-1, + rtol=2e-1, + ): + a, w1, w2, b1, b2, gating_output = gen_moe_inputs( + dtype, + top_k, + num_experts, + hidden_size, + intermediate_size, + num_tokens, + seed=seed, + has_bias=has_bias, + ) + w1_scale = None + w2_scale = None + if w_dtype is not None: + if subc_quant_wsz is None: + subc_quant_wsz = 256 + w1, w1_scale = 
sub_channel_quantize(w1, w_dtype, subc_quant_wsz) + w2, w2_scale = sub_channel_quantize(w2, w_dtype, subc_quant_wsz) + + actual = fused_ep_moe( + mesh=self.mesh, + tokens=a, + w1=w1, + w2=w2, + gating_output=gating_output, + top_k=top_k, + renormalize_topk_logits=renormalize_topk_logits, + act_fn=act_fn, + subc_quant_wsz=subc_quant_wsz, + w1_scale=w1_scale, + w2_scale=w2_scale, + b1=b1, + b2=b2, + bt=bt, + bf=bf, + bd1=bd1, + bd2=bd2, + btc=btc, + bfc=bfc, + bd1c=bd1c, + bd2c=bd2c, + ep_axis_name="tensor", + ) + expected = ref_moe( + a, + w1, + w2, + gating_output, + top_k, + b1=b1, + b2=b2, + renormalize_topk_logits=renormalize_topk_logits, + act_fn=act_fn, + subc_quant_wsz=subc_quant_wsz, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) + self.assertAllClose(actual, expected, atol=atol, rtol=rtol) + + @parameterized.product( + renormalize_topk_logits=[True, False], + ) + def test_basic(self, renormalize_topk_logits): + dtype = jnp.bfloat16 + top_k = 8 + num_experts = 128 + hidden_size = 1024 + intermediate_size = 1024 + num_tokens = 8 * 32 + self._test_moe( + dtype=dtype, + top_k=top_k, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + seed=1234, + renormalize_topk_logits=renormalize_topk_logits, + bt=32, + bf=1024, + bd1=1024, + bd2=1024, + btc=32, + bfc=256, + bd1c=256, + bd2c=256, + ) + + @parameterized.product( + act_fn=["silu", "gelu", "swigluoai"], + ) + def test_activation(self, act_fn): + dtype = jnp.bfloat16 + top_k = 8 + num_experts = 128 + hidden_size = 1024 + intermediate_size = 1024 + num_tokens = 8 * 32 + self._test_moe( + dtype=dtype, + top_k=top_k, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + seed=1234, + renormalize_topk_logits=True, + act_fn=act_fn, + bt=32, + bf=512, + bd1=512, + bd2=512, + btc=32, + bfc=256, + bd1c=256, + bd2c=256, + ) + + def test_benchmark_qwen_235(self): + num_experts = 128 + top_k = 8 + hidden_size = 4096 + intermediate_size = 1536 + dtype = jnp.bfloat16 + num_tokens = 8 * 64 + seed = 54321 + renormalize_topk_logits = True + self._test_moe( + dtype=dtype, + top_k=top_k, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + seed=seed, + renormalize_topk_logits=renormalize_topk_logits, + bt=64, + bf=768, + bd1=2048, + bd2=2048, + btc=64, + bfc=768, + bd1c=2048, + bd2c=2048, + act_fn="silu", + atol=5e-2, + rtol=5e-2, + ) + + def test_benchmark_qwen_30b_a3b(self): + num_experts = 128 + top_k = 8 + hidden_size = 2048 + intermediate_size = 768 + dtype = jnp.bfloat16 + num_tokens = 512 + seed = 54321 + renormalize_topk_logits = True + self._test_moe( + dtype=dtype, + top_k=top_k, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + seed=seed, + renormalize_topk_logits=renormalize_topk_logits, + bt=16, + bf=384, + bd1=512, + bd2=512, + btc=16, + bfc=384, + bd1c=256, + bd2c=256, + act_fn="silu", + atol=5e-2, + rtol=5e-2, + ) + + @parameterized.product( + w_dtype=[jnp.int8, jnp.float8_e5m2, jnp.float4_e2m1fn], + ) + def test_sub_channel_quantization(self, w_dtype): + if w_dtype in ( + jnp.float8_e5m2, + jnp.float4_e2m1fn, + ) and not jtu.is_device_tpu_at_least(version=7): + self.skipTest("Expect TPUv7+") + dtype = jnp.bfloat16 + top_k = 8 + num_experts = 128 + hidden_size = 1024 + intermediate_size = 1024 + num_tokens = 8 * 32 + self._test_moe( + dtype=dtype, + top_k=top_k, + 
num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + seed=1234, + renormalize_topk_logits=False, + w_dtype=w_dtype, + subc_quant_wsz=256, + bt=32, + bf=1024, + bd1=1024, + bd2=1024, + btc=32, + bfc=256, + bd1c=256, + bd2c=256, + ) + + def test_bias(self): + dtype = jnp.bfloat16 + top_k = 8 + num_experts = 128 + hidden_size = 1024 + intermediate_size = 1024 + num_tokens = 8 * 32 + self._test_moe( + dtype=dtype, + top_k=top_k, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_tokens=num_tokens, + seed=1234, + renormalize_topk_logits=False, + has_bias=True, + bt=32, + bf=512, + bd1=512, + bd2=512, + btc=32, + bfc=256, + bd1c=256, + bd2c=256, + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index eca24fc45..02df5e9b0 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -160,6 +160,7 @@ def run_one_file(filename): "unit-test-tpu-v6e-1": [ TestFile("python/sgl_jax/test/test_flashattention.py", 20), TestFile("python/sgl_jax/test/test_moe_topk.py", 1), + TestFile("python/sgl_jax/test/kernels/fused_moe_v1_test.py", 10), TestFile("python/sgl_jax/test/test_sampler.py", 0.2), TestFile("python/sgl_jax/test/test_utils.py", 0.2), TestFile("python/sgl_jax/test/mem_cache/test_kv_cache.py", 20), diff --git a/test/srt/test_qwen3_moe_models.py b/test/srt/test_qwen3_moe_models.py index 9c47ef930..d1874b81d 100644 --- a/test/srt/test_qwen3_moe_models.py +++ b/test/srt/test_qwen3_moe_models.py @@ -109,6 +109,10 @@ def setUpClass(cls): "16", "--page-size", "64", + "--ep-size", + "4", + "--moe-backend", + "fused", ], env={ "JAX_COMPILATION_CACHE_DIR": "/tmp/jax_compilation_cache",
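Finally, a note on how the new flag changes the layer call sites shown in this diff: `--moe-backend` is copied onto the HF config by `ModelRunner.load_model`, and each MoE model branches on it, passing raw router logits to `FusedEPMoE` (top-k selection happens inside the kernel) versus pre-computed top-k weights and ids to `EPMoE`. The snippet below only sketches that control flow with stand-in callables; it is not code from this change.

```python
# Illustrative sketch (not part of this diff) of the calling-convention split the
# moe_backend flag introduces in the model layers above. The lambdas stand in for
# the real FusedEPMoE / EPMoE / TopK modules.
from types import SimpleNamespace

import jax.numpy as jnp

config = SimpleNamespace(moe_backend="fused")  # set from --moe-backend via ModelRunner
use_fused = getattr(config, "moe_backend", "epmoe") == "fused"

fused_mlp = lambda hidden, router_logits: hidden            # FusedEPMoE-style signature
ep_mlp = lambda hidden, topk_weights, topk_ids: hidden      # EPMoE-style signature
topk = lambda router_logits: (jnp.zeros((4, 8)), jnp.zeros((4, 8), jnp.int32))

hidden_states = jnp.zeros((4, 1024))
router_logits = jnp.zeros((4, 128))

if use_fused:
    out = fused_mlp(hidden_states, router_logits)
else:
    topk_weights, topk_ids = topk(router_logits)
    out = ep_mlp(hidden_states, topk_weights, topk_ids)
```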