From 0702131180448f3263d46a4cd278613f6e0107f9 Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Mon, 23 Feb 2026 00:55:16 -0600 Subject: [PATCH 1/3] add example to use autosp --- benchmarks/autosp/.gitignore | 5 + benchmarks/autosp/README.md | 58 ++ benchmarks/autosp/configs/autosp_config.json | 19 + benchmarks/autosp/configs/autosp_config.yaml | 16 + .../autosp/configs/torchcompile_config.json | 14 + .../autosp/configs/torchcompile_config.yaml | 16 + benchmarks/autosp/distributed_attention.py | 93 +++ benchmarks/autosp/ring_attention.py | 530 ++++++++++++++++++ benchmarks/autosp/run.py | 358 ++++++++++++ benchmarks/autosp/run_autosp.sh | 120 ++++ benchmarks/autosp/sp_dp_registry.py | 45 ++ 11 files changed, 1274 insertions(+) create mode 100644 benchmarks/autosp/.gitignore create mode 100644 benchmarks/autosp/README.md create mode 100644 benchmarks/autosp/configs/autosp_config.json create mode 100644 benchmarks/autosp/configs/autosp_config.yaml create mode 100644 benchmarks/autosp/configs/torchcompile_config.json create mode 100644 benchmarks/autosp/configs/torchcompile_config.yaml create mode 100644 benchmarks/autosp/distributed_attention.py create mode 100644 benchmarks/autosp/ring_attention.py create mode 100644 benchmarks/autosp/run.py create mode 100755 benchmarks/autosp/run_autosp.sh create mode 100644 benchmarks/autosp/sp_dp_registry.py diff --git a/benchmarks/autosp/.gitignore b/benchmarks/autosp/.gitignore new file mode 100644 index 000000000..1d27669c4 --- /dev/null +++ b/benchmarks/autosp/.gitignore @@ -0,0 +1,5 @@ +*.log +*.pyc +logs +*. +*.pt diff --git a/benchmarks/autosp/README.md b/benchmarks/autosp/README.md new file mode 100644 index 000000000..5c7a13eb6 --- /dev/null +++ b/benchmarks/autosp/README.md @@ -0,0 +1,58 @@ +# AutoSP Setup Guide + +Quick start guide to clone and set up the AutoSP repository. 
+ +## Prerequisites + +- CUDA 12.8 compatible GPU (recommended) +- Conda installed +- Git + + +### Install dependencies + +```bash +pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 +``` + +```bash +pip install \ + transformers==4.50.3 \ + tokenizers \ + huggingface-hub \ + safetensors \ + datasets \ + accelerate \ + scipy \ + tqdm \ + pyyaml +``` + +### Install DeepSpeed + +```bash +pip install --no-build-isolation git+https://github.com/neeldani/DeepSpeed.git@autosp +``` + +## Benchmarking + +See `benchmarks/autosp/` directory for benchmarking scripts: + +```bash +cd benchmarks/autosp +``` + +#### Run autosp on 2 GPUs +```bash +./run_autosp.sh --compile autosp --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic +``` + +#### Run eager mode ulysses on 2 GPUs +```bash +./run_autosp.sh --compile eager --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic +``` + +#### Run torch.compile'd ulysses on 2 GPUs +```bash +./run_autosp.sh --compile compile --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic +``` diff --git a/benchmarks/autosp/configs/autosp_config.json b/benchmarks/autosp/configs/autosp_config.json new file mode 100644 index 000000000..93deb7402 --- /dev/null +++ b/benchmarks/autosp/configs/autosp_config.json @@ -0,0 +1,19 @@ +{ + + "bf16": { + "enabled": true + }, + + "zero_optimization":{ + "stage": 0 + }, + "compile": { + "deepcompile": true + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} \ No newline at end of file diff --git a/benchmarks/autosp/configs/autosp_config.yaml b/benchmarks/autosp/configs/autosp_config.yaml new file mode 100644 index 000000000..5ba20b9a6 --- /dev/null +++ b/benchmarks/autosp/configs/autosp_config.yaml @@ -0,0 +1,16 @@ 
+compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + deepspeed_config_file: configs/autosp_config.json +distributed_type: DEEPSPEED +machine_rank: 0 +main_training_function: main +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/benchmarks/autosp/configs/torchcompile_config.json b/benchmarks/autosp/configs/torchcompile_config.json new file mode 100644 index 000000000..d61b17b9f --- /dev/null +++ b/benchmarks/autosp/configs/torchcompile_config.json @@ -0,0 +1,14 @@ +{ + "bf16": { + "enabled": true + }, + "zero_optimization":{ + "stage": 0 + }, + "gradient_accumulation_steps": 1, + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/benchmarks/autosp/configs/torchcompile_config.yaml b/benchmarks/autosp/configs/torchcompile_config.yaml new file mode 100644 index 000000000..cebc281c2 --- /dev/null +++ b/benchmarks/autosp/configs/torchcompile_config.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + deepspeed_config_file: configs/torchcompile_config.json +distributed_type: DEEPSPEED +machine_rank: 0 +main_training_function: main +num_machines: 1 +num_processes: 2 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/benchmarks/autosp/distributed_attention.py b/benchmarks/autosp/distributed_attention.py new file mode 100644 index 000000000..b9f9667e1 --- /dev/null +++ b/benchmarks/autosp/distributed_attention.py @@ -0,0 +1,93 @@ +import os +import torch +import torch.distributed as dist +from deepspeed.sequence.layer import DistributedAttention +from sp_dp_registry import get_group, is_setup, 
sp_size + +#TODO: See if there is a better way to pass the mask +_padding_mask_context = None + +def set_padding_mask(mask): + global _padding_mask_context + _padding_mask_context = mask + +def get_padding_mask(): + global _padding_mask_context + return _padding_mask_context + +def ulysses_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=None, + dropout=0.0, + is_causal=False, + **kwargs, +): + assert is_setup(), 'Incorrectly setup SP/DP Groups.' + + gid = dist.get_rank() // sp_size() + group = get_group(gid) + + # Ulysses expects (batch, seq, heads, dim) + # HF standard provides (batch, heads, seq, dim) + q = query_states.transpose(1, 2).contiguous() + k = key_states.transpose(1, 2).contiguous() + v = value_states.transpose(1, 2).contiguous() + + if not hasattr(self, "ulysses_engine"): + self.ulysses_engine = DistributedAttention( + sdpa_wrapper, + group, + scatter_idx=2, # Shard heads + gather_idx=1 # Gather sequences + ) + + attn_output = self.ulysses_engine( + q, k, v, + batch_dim_idx=0, + attn_mask=None, + dropout_p=dropout, + is_causal=is_causal, + scale=scaling + ) + + return attn_output, None + +def sdpa_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True, scale=None): + # Permute from [b, s, n, h] to [b, n, s, h] for SDPA + q = query.permute(0, 2, 1, 3).contiguous() + k = key.permute(0, 2, 1, 3).contiguous() + v = value.permute(0, 2, 1, 3).contiguous() + + # Create the attention mask from padding mask + causal mask + padding_mask = get_padding_mask() + combined_mask = None + + if padding_mask is not None: + B, S = padding_mask.shape # [B, S] + device = padding_mask.device + + causal_mask = torch.tril(torch.ones(S, S, device=device, dtype=torch.bool)) + padding_mask_bool = (padding_mask != 0).unsqueeze(1) + causal_expanded = causal_mask.unsqueeze(0) + combined_mask = causal_expanded & padding_mask_bool + combined_mask = combined_mask.unsqueeze(1) + + elif is_causal: + pass + + output = 
torch.nn.functional.scaled_dot_product_attention( + q, k, v, + attn_mask=combined_mask, + dropout_p=dropout_p, + is_causal=(combined_mask is None and is_causal), + scale=scale, + enable_gqa=False + ) + + # Permute back from [b, n, s, h] to [b, s, n, h] for all-to-all on output + output = output.permute(0, 2, 1, 3).contiguous() + return output diff --git a/benchmarks/autosp/ring_attention.py b/benchmarks/autosp/ring_attention.py new file mode 100644 index 000000000..7b01da7b9 --- /dev/null +++ b/benchmarks/autosp/ring_attention.py @@ -0,0 +1,530 @@ +## Code is taken directly from the RingFlashAttention +## repository: https://github.com/zhuzilin/ring-flash-attention +from typing import Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import inspect +from functools import cache + +from sp_dp_registry import get_group, is_setup, sp_size +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward + +__all__ = ["update_out_and_lse", "RingComm", "get_default_args"] + +## Utility communication files. 
## +@cache +def _get_default_args(func): + spec = inspect.getfullargspec(func) + defaults = spec.defaults if spec.defaults is not None else () + padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults + args = dict(zip(spec.args, padded_defaults)) + if "softcap" in args: + args["softcap"] = 0.0 + return args + + +def get_default_args(func): + if inspect.isfunction(func): + return _get_default_args(func) + else: + # Use the origin _init_fn in CustomOpDef + return _get_default_args(func._init_fn) + + +@torch.jit.script +def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + # For additional context and discussion, please refer to: + # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + + return out, lse + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse( + slice_out, slice_lse, block_out, block_lse + ) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +@torch.jit.script +def flatten_varlen_lse(lse, 
cu_seqlens): + new_lse = [] + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse.append(lse[i, :, : end - start]) + return torch.cat(new_lse, dim=1) + + +@torch.jit.script +def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): + num_seq = len(cu_seqlens) - 1 + num_head = lse.shape[-2] + new_lse = torch.empty( + (num_seq, max_seqlen, num_head, 1), dtype=torch.float32, device=lse.device + ) + for i in range(num_seq): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + new_lse[i, : end - start] = lse[start:end] + return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous() + + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv( + self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + send_op = dist.P2POp( + dist.isend, to_send, self.send_rank, group=self._process_group + ) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] + + def send_recv_kv( + self, + k: torch.Tensor, + v: 
torch.Tensor, + k_buffer: Optional[torch.Tensor] = None, + v_buffer: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) + self.commit() + return next_k, next_v + + +class AllGatherComm: + def __init__(self, group=None) -> None: + self.group = group + self.handles = [] + + def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor): + handle = dist.all_gather_into_tensor( + output_tensor, input_tensor, group=self.group, async_op=True + ) + self.handles.append(handle) + + def wait(self): + for handle in self.handles: + handle.wait() + self.handles = [] + + +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + comm = RingComm(process_group) + + out = None + lse = None + + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + + if not causal or step <= comm.rank: + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q, + "k": k, + "v": v, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal and step == 0, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = 
lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +def ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + + block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + next_dk, next_dv = None, None + next_k, next_v = None, None + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + if step <= kv_comm.rank or not causal: + bwd_causal = causal and step == 0 + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": dout, + "q": q, + "k": k, + "v": v, + "out": out, + "softmax_lse": softmax_lse, + "dq": block_dq_buffer, + "dk": block_dk_buffer, + "dv": block_dv_buffer, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": bwd_causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) + + if dq is None: + dq = block_dq_buffer.to(torch.float32) + dk = block_dk_buffer.to(torch.float32) + dv = block_dv_buffer.to(torch.float32) + else: + dq += block_dq_buffer + d_kv_comm.wait() + dk = block_dk_buffer + next_dk + dv = block_dv_buffer + next_dv + elif step != 0: + d_kv_comm.wait() + dk, dv = next_dk, next_dv + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = d_kv_comm.send_recv_kv(dk, dv) + + 
d_kv_comm.wait() + + return dq.to(torch.bfloat16), next_dk.to(q.dtype), next_dv.to(q.dtype) + + +class RingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = ring_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + 
softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return RingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +# HuggingFace-compatible wrapper for ring attention +# This follows the same pattern as ulysses_attention_forward in distributed_attention.py +def ring_attention_forward( + self, # This will be the LlamaAttention instance + query_states, + key_states, + value_states, + attention_mask=None, + scaling=None, + dropout=0.0, + is_causal=True, + **kwargs, +): + """ + Ring attention forward pass compatible with HuggingFace's attention interface. + + Args: + self: The LlamaAttention module instance + query_states: (batch, heads, seq, dim) - HuggingFace format + key_states: (batch, heads, seq, dim) - HuggingFace format + value_states: (batch, heads, seq, dim) - HuggingFace format + attention_mask: Not used (ring attention handles masking internally) + scaling: Softmax scaling factor + dropout: Dropout probability + is_causal: Whether to use causal masking + **kwargs: Additional arguments (ignored) + + Returns: + tuple: (attn_output, None) where attn_output is (batch, seq, heads, dim) + """ + # Convert from HF format (batch, heads, seq, dim) to flash_attn format (batch, seq, heads, dim) + assert is_setup(), 'Incorrectly setup SP/DP Groups.' 
+ + gid = dist.get_rank() // sp_size() + group = get_group(gid) + + q = query_states.transpose(1, 2).contiguous() + k = key_states.transpose(1, 2).contiguous() + v = value_states.transpose(1, 2).contiguous() + + # Ring attention expects (batch, seq, heads, dim) + # Call the ring flash attention function + attn_output = ring_flash_attn_func( + q, + k, + v, + dropout_p=dropout, + softmax_scale=scaling, + causal=is_causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=group, + ) + + # Output is already in (batch, seq, heads, dim) format, which HF expects after attention + # Note: Llama's forward handles the reshape internally + return attn_output, None diff --git a/benchmarks/autosp/run.py b/benchmarks/autosp/run.py new file mode 100644 index 000000000..35aac1e5d --- /dev/null +++ b/benchmarks/autosp/run.py @@ -0,0 +1,358 @@ +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +import argparse +import random +import time +from datetime import datetime + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from accelerate import Accelerator +from datasets import load_dataset +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, enable_full_determinism + +from deepspeed.compile.passes.sp_compile import prepare_autosp_inputs + +from distributed_attention import ulysses_attention_forward, set_padding_mask +# from ring_attention import ring_attention_forward +from sp_dp_registry import get_group, populate_registry, get_registry + +torch.set_float32_matmul_precision("high") + +def set_seed(seed: int = 42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def seed_worker(worker_id): + worker_seed = 
12 + worker_id + np.random.seed(worker_seed) + random.seed(worker_seed) + +def get_args(): + parser = argparse.ArgumentParser( + description="AutoSP benchmark script for distributed sequence parallel training", + formatter_class=argparse.RawDescriptionHelpFormatter + ) + + parser.add_argument( + "--model_name", + type=str, + default="meta-llama/Llama-2-7b-hf", + help="HuggingFace model name or path" + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size per GPU" + ) + parser.add_argument( + "--num_epochs", + type=int, + default=1, + help="Number of training epochs" + ) + parser.add_argument( + "--seq_length", + type=int, + default=512, + help="Sequence length for training" + ) + parser.add_argument( + "--steps", + type=int, + default=1, + help="Total training steps" + ) + parser.add_argument( + "--learning_rate", + type=float, + default=2e-5, + help="Learning rate for optimizer" + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Gradient accumulation steps" + ) + parser.add_argument( + "--activation_checkpointing", + action="store_true", + help="Enable gradient checkpointing" + ) + parser.add_argument( + "--dataset_name", + type=str, + default="timdettmers/openassistant-guanaco", + help="HuggingFace dataset name" + ) + parser.add_argument( + "--num_layers", + type=int, + default=None, + help="Number of transformer layers (None means use full model)" + ) + + # Compilation arguments + parser.add_argument( + "--compile", + type=str, + default="autosp", + choices=["eager", "compile", "autosp", "ringattn"], + help="Compilation mode: eager (no compilation), compile (torch.compile), autosp (AutoSP), ringattn (ring attention)" + ) + parser.add_argument( + "--backend", + type=str, + default="inductor", + help="Backend compiler (e.g., inductor, cudagraph)" + ) + + parser.add_argument( + "--deterministic", + action="store_true", + help="Enable deterministic mode for reproducibility" + ) + + 
parser.add_argument( + "--print_interval", + type=int, + default=1, + help="Interval for printing metrics" + ) + + parser.add_argument( + "--sp_size", + type=int, + default=2, + help="Sequence parallel size" + ) + parser.add_argument( + "--dp_size", + type=int, + default=1, + help="Data parallel size" + ) + + return parser.parse_args() + +def validate_args(args): + valid_compile_modes = ["eager", "compile", "autosp", "ringattn"] + if args.compile not in valid_compile_modes: + raise ValueError( + f"Invalid compile mode: {args.compile}. " + f"Must be one of {valid_compile_modes}" + ) + + if args.sp_size <= 0 or args.dp_size <= 0: + raise ValueError("sp_size and dp_size must be positive integers") + + if args.seq_length <= 0: + raise ValueError("seq_length must be positive") + + +def print_rank_0(accelerator, *args, **kwargs): + """Print only on main process (rank 0).""" + if accelerator.is_main_process: + print(*args, **kwargs) + + +def main(): + args = get_args() + validate_args(args) + set_seed(12) + + if args.deterministic: + enable_full_determinism(12) + from torch._inductor import config + config.fallback_random = True + torch.use_deterministic_algorithms(True) + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + device = accelerator.device + assert accelerator.num_processes == args.sp_size * args.dp_size, 'Incorrect dp/sp sizing' + + print_rank_0(accelerator, "\n" + "="*60) + print_rank_0(accelerator, "AutoSP Benchmark Configuration") + print_rank_0(accelerator, "="*60) + print_rank_0(accelerator, f"Model: {args.model_name}") + print_rank_0(accelerator, f"Compile Mode: {args.compile}") + print_rank_0(accelerator, f"Backend: {args.backend}") + print_rank_0(accelerator, f"Sequence Parallel Size: {args.sp_size}") + print_rank_0(accelerator, f"Data Parallel Size: {args.dp_size}") + print_rank_0(accelerator, f"Total Processes: {accelerator.num_processes}") + print_rank_0(accelerator, f"Batch Size: {args.batch_size}") + 
print_rank_0(accelerator, f"Sequence Length: {args.seq_length}") + print_rank_0(accelerator, f"Num Layers: {args.num_layers if args.num_layers else 'Full model'}") + print_rank_0(accelerator, f"Deterministic: {args.deterministic}") + print_rank_0(accelerator, f"Activation Checkpointing: {args.activation_checkpointing}") + print_rank_0(accelerator, f"Learning Rate: {args.learning_rate}") + print_rank_0(accelerator, f"Gradient Accumulation Steps: {args.gradient_accumulation_steps}") + print_rank_0(accelerator, "="*60 + "\n") + + ## Set sp/dp groups accordingly. + if args.compile in ['compile', 'eager', 'ringattn']: + populate_registry(args.sp_size, args.dp_size) + + print_rank_0(accelerator, "Loading model and tokenizer...") + + model_name = args.model_name + if args.compile == "autosp": + attention_backend = "sdpa" + else: + if args.compile == "eager" or args.compile == "compile": + from transformers.models.llama import modeling_llama + attention_backend = "ulyssess" + modeling_llama.ALL_ATTENTION_FUNCTIONS["ulyssess"] = ulysses_attention_forward + elif args.compile == "ringattn": + from transformers.models.llama import modeling_llama + attention_backend = "ringattn" + modeling_llama.ALL_ATTENTION_FUNCTIONS["ringattn"] = ring_attention_forward + + if args.num_layers is not None: + model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + print_rank_0(accelerator, f"num_hidden_layers: {model_config.num_hidden_layers} -> {args.num_layers}") + model_config.num_hidden_layers = args.num_layers + model_config._attn_implementation = attention_backend + model = AutoModelForCausalLM.from_config(model_config, trust_remote_code=True) + else: + model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + model_config._attn_implementation = attention_backend + model = AutoModelForCausalLM.from_pretrained(model_name, config=model_config, trust_remote_code=True) + + if args.activation_checkpointing: + model.gradient_checkpointing_enable() + 
+ print_rank_0(accelerator, "Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + print_rank_0(accelerator, "Loading dataset...") + + g = torch.Generator() + g.manual_seed(12) + dataset = load_dataset('ag_news', split='train[:1%]') + + def tokenize_function(examples): + return tokenizer(examples['text'], padding='max_length', max_length=args.seq_length, truncation=True) + + tokenized_dataset = dataset.map(tokenize_function, batched=True) + tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask']) + + num_replicas_ = args.dp_size + rank_ = accelerator.process_index // args.sp_size + + sampler = DistributedSampler(tokenized_dataset, num_replicas=num_replicas_, rank=rank_, seed=12, shuffle=False) + data_loader = DataLoader(tokenized_dataset, batch_size=args.batch_size, sampler=sampler, num_workers=4, worker_init_fn=seed_worker, generator=g) + + optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate) + + model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader) + print_rank_0(accelerator, f"Model prepared: {model.__class__}") + + if args.compile == "autosp": + print_rank_0(accelerator, f"Running autosp with backend={args.backend}") + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._dynamo.config.capture_scalar_outputs = True + model.compile(backend=args.backend) + elif args.compile in ["compile", "ringattn"]: + print_rank_0(accelerator, f"Running torch.compile with backend={args.backend}") + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._dynamo.config.capture_scalar_outputs = True + model = torch.compile(model, backend=args.backend) + else: + print_rank_0(accelerator, f"Running eager") + + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + model_name = args.model_name.split("/")[-1] + exp_name = 
f"{model_name}_np{accelerator.num_processes}_{args.compile}_" \ + f"B{args.backend}_" \ + f"L{0 if args.num_layers is None else args.num_layers}_" \ + f"bs{args.batch_size}_seq{args.seq_length}_" \ + f"T{timestamp}" + + model.train() + global_step = 0 + print_rank_0(accelerator, f"Using global sequence length: {args.seq_length}") + + os.makedirs("logs", exist_ok=True) + loss_log_file = open(f"logs/loss_{args.compile}_seq{args.seq_length}_rank{accelerator.process_index}.csv", "w") + loss_log_file.write("step,loss\n") + + sp_rank = dist.get_rank() % args.sp_size + for epoch in range(args.num_epochs): + start_iter = time.time() + + for step, batch in enumerate(data_loader): + input_ids = batch['input_ids'].to(device) + B, S = input_ids.shape + + label_ids = input_ids.clone() + position_ids = torch.arange(S, device=device).unsqueeze(0) + attention_mask = batch['attention_mask'].to(device) + + if args.compile == 'autosp': + # prepare inputs for autosp + input_ids, label_ids, position_ids, attention_mask = prepare_autosp_inputs( + input_ids, label_ids, position_ids, attention_mask, seq_dim=1 + ) + else: + chunk_size = S // args.sp_size + start = sp_rank * chunk_size + end = start + chunk_size + input_ids = input_ids[:, start:end] + label_ids = label_ids[:, start:end] + position_ids = position_ids[:, start:end] + + # Store the padding mask to be accessed directly in local attention + set_padding_mask(attention_mask) + + outputs = model( + input_ids=input_ids, + labels=label_ids, + position_ids=position_ids, + attention_mask=attention_mask + ) + loss = outputs.loss + + elapsed_time = time.time() - start_iter + alloc_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3) + peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3) + + if global_step % args.print_interval == 0: + print( + f"[Rank {accelerator.process_index}] Epoch {epoch+1}, Step {global_step}, Loss: {loss.item():.4f}, " + f"Time: {elapsed_time:.2f}s, " + f"Alloc Mem: {alloc_mem_gb:.2f} GB, " + f"Peak Mem: 
{peak_mem_gb:.2f} GB" + ) + + accelerator.backward(loss) + + loss_log_file.write(f"{global_step},{loss.item()}\n") + loss_log_file.flush() + + global_step += 1 + if global_step > args.steps: + break + +if __name__ == "__main__": + torch._dynamo.config.accumulated_cache_size_limit = 256 + torch._dynamo.config.cache_size_limit = 128 + torch._dynamo.config.optimize_ddp = False + + main() + diff --git a/benchmarks/autosp/run_autosp.sh b/benchmarks/autosp/run_autosp.sh new file mode 100755 index 000000000..51697004b --- /dev/null +++ b/benchmarks/autosp/run_autosp.sh @@ -0,0 +1,120 @@ +#!/bin/bash + +# Default parameters +MODEL="meta-llama/Llama-2-7b-chat-hf" +COMPILE="eager" +BACKEND="inductor" +SP_SIZE=2 +DP_SIZE=1 +BATCH_SIZE=1 +SEQ_LENGTH=64 +EXTRA_OPTS="" + +while [[ $# -gt 0 ]]; do + case $1 in + --host-ip) + HOST_IP="$2" + shift 2 + ;; + --model) + MODEL="$2" + shift 2 + ;; + --compile) + COMPILE="$2" + shift 2 + ;; + --backend) + BACKEND="$2" + shift 2 + ;; + --batch-size) + BATCH_SIZE="$2" + shift 2 + ;; + --seq-length) + SEQ_LENGTH="$2" + shift 2 + ;; + --sp-size) + SP_SIZE="$2" + shift 2 + ;; + --dp-size) + DP_SIZE="$2" + shift 2 + ;; + --num-layers) + EXTRA_OPTS="${EXTRA_OPTS} --num_layers $2" + shift 2 + ;; + *) + EXTRA_OPTS="${EXTRA_OPTS} $1" + shift + ;; + esac +done + +if [[ "$COMPILE" != "eager" && "$COMPILE" != "compile" && "$COMPILE" != "autosp" && "$COMPILE" != "ringattn" ]]; then + echo "Invalid compile mode: $COMPILE. Choose from eager, compile, autosp, ringattn." 
+ exit 1 +fi + +if [[ -z "${HOST_IP}" ]]; then + HOST_IP=$(hostname -i | awk '{print $1}') +fi + +PORT=$(python3 -c "import socket; s = socket.socket(); s.bind(('', 0)); print(s.getsockname()[1]); s.close()") + +NUM_PROCESSES=$((SP_SIZE * DP_SIZE)) + +CONFIG_FILE="configs/torchcompile_config.yaml" +if [ "${COMPILE}" == "autosp" ]; then + CONFIG_FILE="configs/autosp_config.yaml" +fi + +mkdir -p logs + +# Generate timestamp for log file +TIMESTAMP=$(date +"%Y%m%d_%H%M%S") +LOG_FILE=logs/log_${COMPILE}_sp${SP_SIZE}_dp${DP_SIZE}_seq${SEQ_LENGTH}_${TIMESTAMP}.log + +# Print configuration +echo "" +echo "================================================================" +echo "Configuration" +echo "================================================================" +echo "HOST_IP: ${HOST_IP}" +echo "PORT: ${PORT}" +echo "NUM_PROCESSES: ${NUM_PROCESSES}" +echo "MODEL: ${MODEL}" +echo "COMPILE: ${COMPILE}" +echo "BACKEND: ${BACKEND}" +echo "SP_SIZE: ${SP_SIZE}" +echo "DP_SIZE: ${DP_SIZE}" +echo "BATCH_SIZE: ${BATCH_SIZE}" +echo "SEQ_LENGTH: ${SEQ_LENGTH}" +echo "LOG_FILE: ${LOG_FILE}" +echo "================================================================" +echo "" + +export NCCL_DEBUG=WARN + +# Launch training +accelerate launch \ + --main_process_ip ${HOST_IP} \ + --main_process_port ${PORT} \ + --num_machines 1 \ + --num_processes ${NUM_PROCESSES} \ + --machine_rank 0 \ + --config_file ${CONFIG_FILE} \ + run.py \ + --model_name "${MODEL}" \ + --batch_size ${BATCH_SIZE} \ + --seq_length ${SEQ_LENGTH} \ + --sp_size ${SP_SIZE} \ + --dp_size ${DP_SIZE} \ + --backend ${BACKEND} \ + --compile ${COMPILE} \ + ${EXTRA_OPTS} \ + 2>&1 | tee ${LOG_FILE} diff --git a/benchmarks/autosp/sp_dp_registry.py b/benchmarks/autosp/sp_dp_registry.py new file mode 100644 index 000000000..4fc1913f1 --- /dev/null +++ b/benchmarks/autosp/sp_dp_registry.py @@ -0,0 +1,45 @@ +import torch +import torch.distributed as dist + +GROUP_REGISTRY = {} # int -> dist.ProcessGroup + +def register_groups(groups): 
+    """groups: List[List[int]], e.g. [[0,1],[2,3]]"""
+    for gid, ranks in enumerate(groups):
+        if gid not in GROUP_REGISTRY:
+            GROUP_REGISTRY[gid] = dist.new_group(ranks)
+
+def get_group(gid: int):
+    return GROUP_REGISTRY[gid] if gid is not None else dist.group.WORLD
+
+def get_registry():
+    return GROUP_REGISTRY
+
+def is_setup():
+    return GROUP_REGISTRY['is_reg'] if 'is_reg' in GROUP_REGISTRY else False
+
+def sp_size():
+    assert 'SP_SIZE' in GROUP_REGISTRY, 'SP_SIZE not init properly.'
+
+    return GROUP_REGISTRY['SP_SIZE']
+
+def dp_size():
+    assert 'DP_SIZE' in GROUP_REGISTRY, 'DP_SIZE not init properly'
+
+    return GROUP_REGISTRY['DP_SIZE']
+
+def populate_registry(SP_SIZE, DP_SIZE):
+    ## We register in the run_acc_lm.py file for baselines to reduce code-duplication.
+    ## Else the registration happens within the SP compiler pass within deepspeed.
+    group_listing = []
+    offset = 0
+    for _ in range(DP_SIZE):
+        group_listing.append([i + offset for i in range(SP_SIZE)])
+        offset += SP_SIZE
+
+    register_groups(group_listing)
+
+    ## Extraneous metadata required for proper instantiation.
## + GROUP_REGISTRY['SP_SIZE'] = SP_SIZE + GROUP_REGISTRY['DP_SIZE'] = DP_SIZE + GROUP_REGISTRY['is_reg'] = True From 0154a46aa6a2e2632df865c110e3d7ed13646e84 Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Mon, 23 Feb 2026 01:00:03 -0600 Subject: [PATCH 2/3] fix lint --- benchmarks/autosp/configs/autosp_config.json | 2 +- benchmarks/autosp/configs/torchcompile_config.yaml | 2 +- benchmarks/autosp/sp_dp_registry.py | 2 -- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/benchmarks/autosp/configs/autosp_config.json b/benchmarks/autosp/configs/autosp_config.json index 93deb7402..85ed38383 100644 --- a/benchmarks/autosp/configs/autosp_config.json +++ b/benchmarks/autosp/configs/autosp_config.json @@ -16,4 +16,4 @@ "train_batch_size": "auto", "train_micro_batch_size_per_gpu": "auto", "wall_clock_breakdown": false -} \ No newline at end of file +} diff --git a/benchmarks/autosp/configs/torchcompile_config.yaml b/benchmarks/autosp/configs/torchcompile_config.yaml index cebc281c2..2e35b1185 100644 --- a/benchmarks/autosp/configs/torchcompile_config.yaml +++ b/benchmarks/autosp/configs/torchcompile_config.yaml @@ -13,4 +13,4 @@ same_network: true tpu_env: [] tpu_use_cluster: false tpu_use_sudo: false -use_cpu: false \ No newline at end of file +use_cpu: false diff --git a/benchmarks/autosp/sp_dp_registry.py b/benchmarks/autosp/sp_dp_registry.py index 4fc1913f1..ebb29d91a 100644 --- a/benchmarks/autosp/sp_dp_registry.py +++ b/benchmarks/autosp/sp_dp_registry.py @@ -29,8 +29,6 @@ def dp_size(): return GROUP_REGISTRY['DP_SIZE'] def populate_registry(SP_SIZE, DP_SIZE): - ## We register in the run_acc_lm.py file for baselines to reduce code-duplication. - ## Else the registration happens within the SP compiler pass within deepspeed. 
group_listing = [] offset = 0 for _ in range(DP_SIZE): From 2459c001b54a223b636cee639736c938a9138eed Mon Sep 17 00:00:00 2001 From: Neel Dani Date: Mon, 23 Feb 2026 01:43:29 -0600 Subject: [PATCH 3/3] remove pre-requisite section from the readme --- benchmarks/autosp/README.md | 6 ------ 1 file changed, 6 deletions(-) diff --git a/benchmarks/autosp/README.md b/benchmarks/autosp/README.md index 5c7a13eb6..95af0c3bf 100644 --- a/benchmarks/autosp/README.md +++ b/benchmarks/autosp/README.md @@ -2,12 +2,6 @@ Quick start guide to clone and set up the AutoSP repository. -## Prerequisites - -- CUDA 12.8 compatible GPU (recommended) -- Conda installed -- Git - ### Install dependencies