"""Benchmark script comparing FusedEPMoE vs EPMoE implementations.

This script performs layer-level benchmarking with synthetic weights and controlled
token distribution scenarios (random, balanced, imbalanced).

Example usage:
    # Quick test with a manual configuration
    python benchmark/fused_moe/bench_fused_vs_epmoe.py \
        --manual-config \
        --num-experts 8 --num-experts-per-tok 2 \
        --hidden-size 1024 --intermediate-size 4096 \
        --num-tokens 512 --scenarios random

    # Using HF model config
    python benchmark/fused_moe/bench_fused_vs_epmoe.py \
        --model-path Qwen/Qwen2.5-MoE-A2.7B \
        --ep-size 8 --num-tokens 1024 2048 4096 \
        --scenarios random balanced imbalanced \
        --profile --profile-dir ./profiles/qwen

    # 4 GPUs with expert parallelism (tp=4, ep=4, tp_actual=1)
    python benchmark/fused_moe/bench_fused_vs_epmoe.py \
        --model-path Qwen/Qwen2.5-MoE-A2.7B \
        --tp-size 4 --ep-size 4

    # 8 GPUs with expert and tensor parallelism (tp=8, ep=4, tp_actual=2)
    python benchmark/fused_moe/bench_fused_vs_epmoe.py \
        --model-path Qwen/Qwen2.5-MoE-A2.7B \
        --tp-size 8 --ep-size 4
"""

import argparse
import os
import sys

import jax

# Add python directory to path for imports
benchmark_dir = os.path.dirname(os.path.abspath(__file__))  # benchmark/fused_moe
benchmark_root = os.path.dirname(benchmark_dir)  # benchmark
project_root = os.path.dirname(benchmark_root)  # sgl-jax
python_dir = os.path.join(project_root, "python")  # sgl-jax/python
sys.path.insert(0, python_dir)
sys.path.insert(0, project_root)  # For benchmark imports

from benchmark.fused_moe.benchmark_runner import MoEBenchmarkRunner  # noqa: E402
from benchmark.fused_moe.config_utils import MoEBenchmarkConfig  # noqa: E402
from benchmark.fused_moe.output_formatter import save_results  # noqa: E402
from benchmark.fused_moe.synthetic_data import create_synthetic_weights  # noqa: E402


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Benchmark FusedEPMoE vs EPMoE implementations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Model configuration (mutually exclusive)
    config_group = parser.add_mutually_exclusive_group(required=True)
    config_group.add_argument(
        "--model-path",
        type=str,
        help="Path or name of HuggingFace model to load config from",
    )
    config_group.add_argument(
        "--manual-config",
        action="store_true",
        help="Use manual configuration (requires --num-experts, etc.)",
    )

    # Manual configuration options
    parser.add_argument("--num-experts", type=int, help="Number of experts")
    parser.add_argument("--num-experts-per-tok", type=int, help="Top-k value")
    parser.add_argument("--hidden-size", type=int, help="Hidden dimension")
    parser.add_argument("--intermediate-size", type=int, help="Intermediate dimension")
    parser.add_argument(
        "--activation",
        type=str,
        default="silu",
        choices=["silu", "gelu", "swigluoai"],
        help="Activation function",
    )

    # Distributed configuration
    parser.add_argument(
        "--ep-size",
        type=int,
        default=1,
        help="Expert parallel size (default: 1)",
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Total number of devices to use (default: 1)",
    )
    parser.add_argument(
        "--dist-init-addr",
        type=str,
        help="Distributed initialization address (e.g., 10.0.0.1:12345)",
    )
    parser.add_argument(
        "--nnodes",
        type=int,
        default=1,
        help="Number of nodes (default: 1)",
    )
    parser.add_argument(
        "--node-rank",
        type=int,
        default=0,
        help="Current node rank (default: 0)",
    )

    # Benchmark parameters
    parser.add_argument(
        "--num-tokens",
        type=int,
        nargs="+",
        default=[512, 1024, 2048],
        help="List of token counts to test (default: 512 1024 2048)",
    )
    parser.add_argument(
        "--scenarios",
        type=str,
        nargs="+",
        default=["random", "balanced", "imbalanced"],
        choices=["random", "balanced", "imbalanced"],
        help="Scenarios to test (default: all)",
    )
    parser.add_argument(
        "--imbalance-factor",
        type=float,
        default=3.0,
        help="Target imbalance factor for 'imbalanced' scenario (default: 3.0)",
    )
    parser.add_argument(
        "--warmup-iters",
        type=int,
        default=1,
        help="Warmup iterations (default: 1; one is enough to trigger JAX JIT compilation)",
    )
    parser.add_argument(
        "--benchmark-iters",
        type=int,
        default=10,
        help="Benchmark iterations (default: 10)",
    )

    # Profiling
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Enable JAX profiler",
    )
    parser.add_argument(
        "--profile-dir",
        type=str,
        default="./profiles",
        help="Profile output directory (default: ./profiles)",
    )

    # Output
    parser.add_argument(
        "--output-format",
        type=str,
        default="both",
        choices=["csv", "markdown", "both"],
        help="Output format (default: both)",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="./benchmark_results",
        help="Output file base path (default: ./benchmark_results)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )

    args = parser.parse_args()

    # Validate manual config
    if args.manual_config:
        required_manual = ["num_experts", "num_experts_per_tok", "hidden_size", "intermediate_size"]
        missing = [arg for arg in required_manual if getattr(args, arg) is None]
        if missing:
            parser.error(
                f"--manual-config requires: {', '.join('--' + m.replace('_', '-') for m in missing)}"
            )

    return args


def setup_distributed(args: argparse.Namespace) -> None:
    """Initialize JAX distributed environment if needed."""
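    # Single-node runs skip jax.distributed.initialize and use the locally visible devices.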
    if args.nnodes > 1:
        if not args.dist_init_addr:
            raise ValueError("--dist-init-addr is required for multi-node setup")

        print(f"Initializing distributed: nnodes={args.nnodes}, rank={args.node_rank}")
        jax.distributed.initialize(
            coordinator_address=args.dist_init_addr,
            num_processes=args.nnodes,
            process_id=args.node_rank,
        )
        print(f"Distributed initialized. Process rank: {jax.process_index()}")


def create_mesh(tp_size: int) -> jax.sharding.Mesh:
    """
    Create JAX mesh for MoE execution using create_device_mesh.

    This follows the same logic as scheduler.py. The MoE layers (FusedEPMoE and EPMoE)
    will internally compute world_size from the mesh and calculate the actual tensor
    parallel size as: tp_actual = world_size // ep_size

    Args:
        tp_size: Total number of devices to use

    Returns:
        JAX mesh with (data, tensor) axes
    """
    from sgl_jax.srt.utils.mesh_utils import create_device_mesh

    mesh = create_device_mesh(
        ici_parallelism=[-1, tp_size],
        dcn_parallelism=[1, 1],
    )
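    # Note (assumption based on the docstring above and scheduler.py): ici_parallelism=[-1, tp_size]
    # is expected to place tp_size devices on the tensor axis and any remainder on the data axis,
    # e.g. --tp-size 8 with --ep-size 4 gives the MoE layers tp_actual = 8 // 4 = 2.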

    return mesh


def load_or_create_config(args: argparse.Namespace) -> MoEBenchmarkConfig:
    """Load configuration from model path or create from manual args."""
    if args.model_path:
        print(f"Loading config from model: {args.model_path}")
        config = MoEBenchmarkConfig.from_model_path(
            args.model_path,
            ep_size=args.ep_size,
            tp_size=args.tp_size,
        )
    else:
        print("Using manual configuration")
        config = MoEBenchmarkConfig(
            num_experts=args.num_experts,
            num_experts_per_tok=args.num_experts_per_tok,
            hidden_size=args.hidden_size,
            intermediate_size=args.intermediate_size,
            activation=args.activation,
            ep_size=args.ep_size,
            tp_size=args.tp_size,
        )

    # Validate config
    config.validate()

    if args.verbose:
        print("\n" + str(config))

    return config


def main():
    """Main execution flow."""
    args = parse_args()

    print("=" * 80)
    print("MoE Benchmark: FusedEPMoE vs EPMoE")
    print("=" * 80)

    # Setup distributed
    setup_distributed(args)

    # Create mesh
    print(f"\nCreating JAX mesh: tp_size={args.tp_size}, ep_size={args.ep_size}")
    mesh = create_mesh(args.tp_size)
    print(f"Mesh created with {len(mesh.devices.flatten())} devices")
    print(f"Mesh shape: {mesh.shape}")

    # Load configuration
    config = load_or_create_config(args)

    # Generate synthetic weights
    print("\nGenerating synthetic weights...")
    fused_weights, epmoe_weights = create_synthetic_weights(config, mesh)
    print(f"Weights generated: w1={fused_weights['w1'].shape}, w2={fused_weights['w2'].shape}")

    # Initialize benchmark runner
    print("\nInitializing benchmark runner...")
    runner = MoEBenchmarkRunner(
        config=config,
        mesh=mesh,
        warmup_iters=args.warmup_iters,
        benchmark_iters=args.benchmark_iters,
        verbose=args.verbose,
    )

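    # Initialize both layers from the synthetic weights generated above (fused and EPMoE layouts).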
    runner.initialize_layers(fused_weights, epmoe_weights)

    # Run benchmarks
    print("\n" + "=" * 80)
    print("Running Benchmarks")
    print("=" * 80)

    all_results = []

    for scenario in args.scenarios:
        for num_tokens in args.num_tokens:
            print(f"\n{'=' * 80}")
            print(f"Scenario: {scenario}, Tokens: {num_tokens}")
            print(f"{'=' * 80}")

            if args.profile:
                # Profile each implementation separately
                profile_dir_fused = os.path.join(
                    args.profile_dir, f"{scenario}_tokens{num_tokens}_fused"
                )
                profile_dir_epmoe = os.path.join(
                    args.profile_dir, f"{scenario}_tokens{num_tokens}_epmoe"
                )

                os.makedirs(profile_dir_fused, exist_ok=True)
                os.makedirs(profile_dir_epmoe, exist_ok=True)

                print(f"Profiling enabled: {profile_dir_fused}, {profile_dir_epmoe}")

                # Run with profiling
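                # benchmark_scenario returns (fused_result, epmoe_result); each trace wraps one
                # full call, and only the half matching that trace's directory name is kept.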
                jax.profiler.start_trace(profile_dir_fused)
                fused_result, _ = runner.benchmark_scenario(
                    scenario, num_tokens, args.imbalance_factor
                )
                jax.profiler.stop_trace()

                jax.profiler.start_trace(profile_dir_epmoe)
                _, epmoe_result = runner.benchmark_scenario(
                    scenario, num_tokens, args.imbalance_factor
                )
                jax.profiler.stop_trace()

                all_results.extend([fused_result, epmoe_result])

            else:
                # Run without profiling
                fused_result, epmoe_result = runner.benchmark_scenario(
                    scenario, num_tokens, args.imbalance_factor
                )
                all_results.extend([fused_result, epmoe_result])

            # Print summary
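            # Speedup > 1 means FusedEPMoE is faster (EPMoE mean latency / FusedEPMoE mean latency).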
            speedup = epmoe_result.latency_mean / fused_result.latency_mean
            print("\nResults:")
            print(f"  FusedEPMoE: {fused_result.latency_mean:.4f} ms (mean)")
            print(f"  EPMoE: {epmoe_result.latency_mean:.4f} ms (mean)")
            print(f"  Speedup: {speedup:.2f}x")
            print(f"  Imbalance: {fused_result.max_imbalance:.2f}x")

    # Save results
    print("\n" + "=" * 80)
    print("Saving Results")
    print("=" * 80)

    save_results(all_results, args.output_file, args.output_format)

    print("\n" + "=" * 80)
    print("Benchmark Complete!")
    print("=" * 80)


if __name__ == "__main__":
    main()