Skip to content

Commit 43f8198

Browse files
committed
add fused moe benchmark
1 parent 9f1babb commit 43f8198

File tree

8 files changed

+1891
-0
lines changed

8 files changed

+1891
-0
lines changed

benchmark/fused_moe/README.md

Lines changed: 428 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 373 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,373 @@
1+
"""Benchmark script comparing FusedEPMoE vs EPMoE implementations.
2+
3+
This script performs layer-level benchmarking with synthetic weights and controlled
4+
token distribution scenarios (random, balanced, imbalanced).
5+
6+
Example usage:
7+
# Quick test
8+
python benchmark/fused_moe/bench_fused_vs_epmoe.py \
9+
--manual-config --num-experts 8 --num-experts-per-tok 2 \
10+
--hidden-size 1024 --intermediate-size 4096 \
11+
--num-tokens 512 --scenarios random
12+
13+
# Using HF model config
14+
python benchmark/fused_moe/bench_fused_vs_epmoe.py \
15+
--model-path Qwen/Qwen2.5-MoE-A2.7B \
16+
--ep-size 8 --num-tokens 1024 2048 4096 \
17+
--scenarios random balanced imbalanced \
18+
--profile --profile-dir ./profiles/qwen
19+
20+
# 4 GPUs with expert parallelism (tp=4, ep=4, tp_actual=1)
21+
python benchmark/fused_moe/bench_fused_vs_epmoe.py \
22+
--model-path Qwen/Qwen2.5-MoE-A2.7B \
23+
--tp-size 4 --ep-size 4
24+
25+
# 8 GPUs with expert and tensor parallelism (tp=8, ep=4, tp_actual=2)
26+
python benchmark/fused_moe/bench_fused_vs_epmoe.py \
27+
--model-path Qwen/Qwen2.5-MoE-A2.7B \
28+
--tp-size 8 --ep-size 4
29+
"""
30+
31+
import argparse
32+
import os
33+
import sys
34+
35+
import jax
36+
37+
# Resolve the repository layout relative to this file so the script can be run
# directly: the repository root is needed for the ``benchmark.*`` imports below
# and its ``python/`` sub-directory is put on the path as well.
benchmark_dir = os.path.dirname(os.path.abspath(__file__))  # benchmark/fused_moe
benchmark_root = os.path.dirname(benchmark_dir)  # benchmark
project_root = os.path.dirname(benchmark_root)  # sgl-jax
python_dir = os.path.join(project_root, "python")  # sgl-jax/python

# Prepend both entries at once; the repository root ends up first, followed
# by the python directory.
sys.path[:0] = [project_root, python_dir]
44+
45+
from benchmark.fused_moe.benchmark_runner import MoEBenchmarkRunner # noqa: E402
46+
from benchmark.fused_moe.config_utils import MoEBenchmarkConfig # noqa: E402
47+
from benchmark.fused_moe.output_formatter import save_results # noqa: E402
48+
from benchmark.fused_moe.synthetic_data import create_synthetic_weights # noqa: E402
49+
50+
51+
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse and validate the benchmark's command-line arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour; passing an
            explicit list lets callers and tests drive the parser directly.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when the arguments are
            malformed, when ``--manual-config`` is given without all of the
            manual model dimensions, or when a size/iteration count is not in
            its valid range.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark FusedEPMoE vs EPMoE implementations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Model configuration (mutually exclusive): exactly one source must be chosen.
    config_group = parser.add_mutually_exclusive_group(required=True)
    config_group.add_argument(
        "--model-path",
        type=str,
        help="Path or name of HuggingFace model to load config from",
    )
    config_group.add_argument(
        "--manual-config",
        action="store_true",
        help="Use manual configuration (requires --num-experts, etc.)",
    )

    # Manual configuration options
    parser.add_argument("--num-experts", type=int, help="Number of experts")
    parser.add_argument("--num-experts-per-tok", type=int, help="Top-k value")
    parser.add_argument("--hidden-size", type=int, help="Hidden dimension")
    parser.add_argument("--intermediate-size", type=int, help="Intermediate dimension")
    parser.add_argument(
        "--activation",
        type=str,
        default="silu",
        choices=["silu", "gelu", "swigluoai"],
        help="Activation function",
    )

    # Distributed configuration
    parser.add_argument(
        "--ep-size",
        type=int,
        default=1,
        help="Expert parallel size (default: 1)",
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Total number of devices to use (default: 1)",
    )
    parser.add_argument(
        "--dist-init-addr",
        type=str,
        help="Distributed initialization address (e.g., 10.0.0.1:12345)",
    )
    parser.add_argument(
        "--nnodes",
        type=int,
        default=1,
        help="Number of nodes (default: 1)",
    )
    parser.add_argument(
        "--node-rank",
        type=int,
        default=0,
        help="Current node rank (default: 0)",
    )

    # Benchmark parameters
    parser.add_argument(
        "--num-tokens",
        type=int,
        nargs="+",
        default=[512, 1024, 2048],
        help="List of token counts to test (default: 512 1024 2048)",
    )
    parser.add_argument(
        "--scenarios",
        type=str,
        nargs="+",
        default=["random", "balanced", "imbalanced"],
        choices=["random", "balanced", "imbalanced"],
        help="Scenarios to test (default: all)",
    )
    parser.add_argument(
        "--imbalance-factor",
        type=float,
        default=3.0,
        help="Target imbalance factor for 'imbalanced' scenario (default: 3.0)",
    )
    parser.add_argument(
        "--warmup-iters",
        type=int,
        default=1,
        help="Warmup iterations (default: 1, only need one for JAX JIT)",
    )
    parser.add_argument(
        "--benchmark-iters",
        type=int,
        default=10,
        help="Benchmark iterations (default: 10)",
    )

    # Profiling
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Enable JAX profiler",
    )
    parser.add_argument(
        "--profile-dir",
        type=str,
        default="./profiles",
        help="Profile output directory (default: ./profiles)",
    )

    # Output
    parser.add_argument(
        "--output-format",
        type=str,
        default="both",
        choices=["csv", "markdown", "both"],
        help="Output format (default: both)",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="./benchmark_results",
        help="Output file base path (default: ./benchmark_results)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args(argv)

    # The manual dimensions have no defaults, so report every missing one at once.
    if args.manual_config:
        required_manual = ["num_experts", "num_experts_per_tok", "hidden_size", "intermediate_size"]
        missing = [name for name in required_manual if getattr(args, name) is None]
        if missing:
            parser.error(
                f"--manual-config requires: {', '.join('--' + m.replace('_', '-') for m in missing)}"
            )

    # Fail fast on values that could only blow up later, after the mesh and
    # synthetic weights have already been built.
    if args.tp_size < 1:
        parser.error("--tp-size must be >= 1")
    if args.ep_size < 1:
        parser.error("--ep-size must be >= 1")
    if any(count < 1 for count in args.num_tokens):
        parser.error("--num-tokens values must be >= 1")
    if args.warmup_iters < 0:
        parser.error("--warmup-iters must be >= 0")
    if args.benchmark_iters < 1:
        parser.error("--benchmark-iters must be >= 1")

    return args
195+
196+
197+
def setup_distributed(args: argparse.Namespace) -> None:
    """Bring up ``jax.distributed`` when the run spans more than one node.

    Args:
        args: Parsed CLI namespace; reads ``nnodes``, ``node_rank`` and
            ``dist_init_addr``.

    Raises:
        ValueError: If ``nnodes`` is greater than 1 but no ``--dist-init-addr``
            was supplied.
    """
    # Single-node runs need no coordinator, so there is nothing to do.
    if args.nnodes <= 1:
        return

    coordinator = args.dist_init_addr
    if not coordinator:
        raise ValueError("--dist-init-addr is required for multi-node setup")

    print(f"Initializing distributed: nnodes={args.nnodes}, rank={args.node_rank}")
    # One process per node: the node count and rank are passed straight
    # through as the process count and process id.
    jax.distributed.initialize(
        coordinator_address=coordinator,
        num_processes=args.nnodes,
        process_id=args.node_rank,
    )
    print(f"Distributed initialized. Process rank: {jax.process_index()}")
210+
211+
212+
def create_mesh(tp_size: int) -> jax.sharding.Mesh:
    """Build the JAX device mesh the MoE layers run on.

    Delegates to ``create_device_mesh`` from ``sgl_jax.srt.utils.mesh_utils``
    with ICI parallelism ``[-1, tp_size]`` and DCN parallelism ``[1, 1]``.
    Per the original author's note this mirrors the logic in scheduler.py, and
    the MoE layers (FusedEPMoE and EPMoE) derive their effective tensor-parallel
    size themselves as ``tp_actual = world_size // ep_size``.

    Args:
        tp_size: Total number of devices to use.

    Returns:
        The mesh produced by ``create_device_mesh`` (described by the original
        author as having ``(data, tensor)`` axes -- confirm against the helper).
    """
    # Imported lazily: the module only becomes importable after the sys.path
    # tweaks at the top of this script have run.
    from sgl_jax.srt.utils.mesh_utils import create_device_mesh

    # NOTE(review): -1 presumably lets the helper infer the first axis from
    # the devices left over after ``tp_size`` -- confirm in mesh_utils.
    return create_device_mesh(
        ici_parallelism=[-1, tp_size],
        dcn_parallelism=[1, 1],
    )
234+
235+
236+
def load_or_create_config(args: argparse.Namespace) -> MoEBenchmarkConfig:
    """Build and validate the benchmark configuration from the CLI arguments.

    When ``args.model_path`` is set the config is loaded through
    ``MoEBenchmarkConfig.from_model_path``; otherwise it is constructed from the
    manually supplied model dimensions. Either way the ``--ep-size`` and
    ``--tp-size`` values are forwarded, ``validate()`` is called on the result,
    and the config is printed when ``args.verbose`` is set.

    Args:
        args: Parsed CLI namespace.

    Returns:
        The validated ``MoEBenchmarkConfig``.
    """
    # Both construction paths take the same parallelism settings.
    parallel_sizes = {"ep_size": args.ep_size, "tp_size": args.tp_size}

    if not args.model_path:
        print("Using manual configuration")
        config = MoEBenchmarkConfig(
            num_experts=args.num_experts,
            num_experts_per_tok=args.num_experts_per_tok,
            hidden_size=args.hidden_size,
            intermediate_size=args.intermediate_size,
            activation=args.activation,
            **parallel_sizes,
        )
    else:
        print(f"Loading config from model: {args.model_path}")
        config = MoEBenchmarkConfig.from_model_path(args.model_path, **parallel_sizes)

    # Validate before anything is allocated from this config.
    config.validate()

    if args.verbose:
        print("\n" + str(config))

    return config
264+
265+
266+
def _run_traced(runner, trace_dir, scenario, num_tokens, imbalance_factor):
    """Run one benchmark scenario while recording a JAX profiler trace to ``trace_dir``."""
    # The context manager stops the trace even when the benchmark raises, so a
    # failed run cannot leave the profiler active for the next scenario.
    with jax.profiler.trace(trace_dir):
        return runner.benchmark_scenario(scenario, num_tokens, imbalance_factor)


def main():
    """Run the FusedEPMoE vs EPMoE benchmark end to end.

    Steps: parse the CLI arguments, initialise distributed JAX if requested,
    build the device mesh, load the model configuration, generate synthetic
    weights, benchmark every (scenario, token count) combination -- optionally
    under the JAX profiler -- and finally write the collected results with
    ``save_results``.
    """
    rule = "=" * 80
    args = parse_args()

    print(rule)
    print("MoE Benchmark: FusedEPMoE vs EPMoE")
    print(rule)

    # Setup distributed
    setup_distributed(args)

    # Create mesh
    print(f"\nCreating JAX mesh: tp_size={args.tp_size}, ep_size={args.ep_size}")
    mesh = create_mesh(args.tp_size)
    print(f"Mesh created with {mesh.devices.size} devices")
    print(f"Mesh shape: {mesh.shape}")

    # Load configuration
    config = load_or_create_config(args)

    # Generate synthetic weights
    print("\nGenerating synthetic weights...")
    fused_weights, epmoe_weights = create_synthetic_weights(config, mesh)
    print(f"Weights generated: w1={fused_weights['w1'].shape}, w2={fused_weights['w2'].shape}")

    # Initialize benchmark runner
    print("\nInitializing benchmark runner...")
    runner = MoEBenchmarkRunner(
        config=config,
        mesh=mesh,
        warmup_iters=args.warmup_iters,
        benchmark_iters=args.benchmark_iters,
        verbose=args.verbose,
    )
    runner.initialize_layers(fused_weights, epmoe_weights)

    # Run benchmarks
    print("\n" + rule)
    print("Running Benchmarks")
    print(rule)

    all_results = []

    for scenario in args.scenarios:
        for num_tokens in args.num_tokens:
            print(f"\n{rule}")
            print(f"Scenario: {scenario}, Tokens: {num_tokens}")
            print(rule)

            if args.profile:
                # One trace directory per implementation.
                profile_dir_fused = os.path.join(
                    args.profile_dir, f"{scenario}_tokens{num_tokens}_fused"
                )
                profile_dir_epmoe = os.path.join(
                    args.profile_dir, f"{scenario}_tokens{num_tokens}_epmoe"
                )
                os.makedirs(profile_dir_fused, exist_ok=True)
                os.makedirs(profile_dir_epmoe, exist_ok=True)

                print(f"Profiling enabled: {profile_dir_fused}, {profile_dir_epmoe}")

                # NOTE(review): benchmark_scenario returns results for both
                # implementations, so each trace presumably contains both
                # kernels; only the matching result is kept from each run.
                fused_result, _ = _run_traced(
                    runner, profile_dir_fused, scenario, num_tokens, args.imbalance_factor
                )
                _, epmoe_result = _run_traced(
                    runner, profile_dir_epmoe, scenario, num_tokens, args.imbalance_factor
                )
            else:
                # Run without profiling
                fused_result, epmoe_result = runner.benchmark_scenario(
                    scenario, num_tokens, args.imbalance_factor
                )

            all_results.extend([fused_result, epmoe_result])

            # Print summary. Guard the ratio so a zero fused latency does not
            # abort the whole sweep with ZeroDivisionError.
            if fused_result.latency_mean > 0:
                speedup = epmoe_result.latency_mean / fused_result.latency_mean
            else:
                speedup = float("inf")
            print("\nResults:")
            print(f" FusedEPMoE: {fused_result.latency_mean:.4f} ms (mean)")
            print(f" EPMoE: {epmoe_result.latency_mean:.4f} ms (mean)")
            print(f" Speedup: {speedup:.2f}x")
            print(f" Imbalance: {fused_result.max_imbalance:.2f}x")

    # Save results
    print("\n" + rule)
    print("Saving Results")
    print(rule)

    save_results(all_results, args.output_file, args.output_format)

    print("\n" + rule)
    print("Benchmark Complete!")
    print(rule)
370+
371+
372+
if __name__ == "__main__":
373+
main()

0 commit comments

Comments
 (0)