Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions benchmarks/autosp/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
*.log
*.pyc
logs
*.
*.pt
52 changes: 52 additions & 0 deletions benchmarks/autosp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# AutoSP Setup Guide

Quick start guide to clone and set up the AutoSP repository.


### Install dependencies

```bash
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
```

```bash
pip install \
transformers==4.50.3 \
tokenizers \
huggingface-hub \
safetensors \
datasets \
accelerate \
scipy \
tqdm \
pyyaml
```

### Install DeepSpeed

```bash
pip install --no-build-isolation git+https://github.com/neeldani/DeepSpeed.git@autosp
```

## Benchmarking

See `benchmarks/autosp/` directory for benchmarking scripts:

```bash
cd benchmarks/autosp
```

#### Run autosp on 2 GPUs
```bash
./run_autosp.sh --compile autosp --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic
```

#### Run eager mode ulysses on 2 GPUs
```bash
./run_autosp.sh --compile eager --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic
```

#### Run torch.compile'd ulysses on 2 GPUs
```bash
./run_autosp.sh --compile compile --batch-size 1 --seq-length 64 --sp-size 2 --num-layers 1 --steps 1 --deterministic
```
19 changes: 19 additions & 0 deletions benchmarks/autosp/configs/autosp_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{

"bf16": {
"enabled": true
},

"zero_optimization":{
"stage": 0
},
"compile": {
"deepcompile": true
},
"gradient_accumulation_steps": 1,
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
16 changes: 16 additions & 0 deletions benchmarks/autosp/configs/autosp_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
deepspeed_config_file: configs/autosp_config.json
distributed_type: DEEPSPEED
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
14 changes: 14 additions & 0 deletions benchmarks/autosp/configs/torchcompile_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"bf16": {
"enabled": true
},
"zero_optimization":{
"stage": 0
},
"gradient_accumulation_steps": 1,
"gradient_clipping": "auto",
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
16 changes: 16 additions & 0 deletions benchmarks/autosp/configs/torchcompile_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
deepspeed_config_file: configs/torchcompile_config.json
distributed_type: DEEPSPEED
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
93 changes: 93 additions & 0 deletions benchmarks/autosp/distributed_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os
import torch
import torch.distributed as dist
from deepspeed.sequence.layer import DistributedAttention
from sp_dp_registry import get_group, is_setup, sp_size

# TODO: See if there is a better way to pass the mask
# Module-level slot for the padding mask of the current forward pass: the
# caller stashes it here and sdpa_wrapper (which has no mask argument wired
# through DistributedAttention) reads it back out.
_padding_mask_context = None

def set_padding_mask(mask):
    """Stash *mask* in the module-level context for later retrieval."""
    global _padding_mask_context
    _padding_mask_context = mask

def get_padding_mask():
    """Return the most recently stashed padding mask (None if unset)."""
    # Reading a module global needs no `global` declaration.
    return _padding_mask_context

def ulysses_attention_forward(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    scaling=None,
    dropout=0.0,
    is_causal=False,
    **kwargs,
):
    """Attention forward that routes Q/K/V through DeepSpeed Ulysses.

    Intended to replace an HF-style attention forward: inputs arrive as
    (batch, heads, seq, dim) and are handed to a cached DistributedAttention
    engine, which all-to-alls them across the sequence-parallel group and
    invokes ``sdpa_wrapper`` on the resharded tensors.

    Returns a ``(attn_output, attn_weights)`` pair; attention weights are
    never computed, so the second element is always ``None``.

    NOTE(review): ``attention_mask`` is deliberately ignored here — the mask
    reaches ``sdpa_wrapper`` via the module-level padding-mask context.
    """
    assert is_setup(), 'Incorrectly setup SP/DP Groups.'

    # Map this rank to its sequence-parallel process group.
    group_id = dist.get_rank() // sp_size()
    sp_group = get_group(group_id)

    # Ulysses expects (batch, seq, heads, dim); HF hands us
    # (batch, heads, seq, dim), so swap dims 1 and 2 up front.
    q_bshd = query_states.transpose(1, 2).contiguous()
    k_bshd = key_states.transpose(1, 2).contiguous()
    v_bshd = value_states.transpose(1, 2).contiguous()

    # Build the DistributedAttention engine once per module and cache it.
    if not hasattr(self, "ulysses_engine"):
        self.ulysses_engine = DistributedAttention(
            sdpa_wrapper,
            sp_group,
            scatter_idx=2,  # shard the heads dimension across ranks
            gather_idx=1,   # gather the sequence dimension
        )

    attn_output = self.ulysses_engine(
        q_bshd,
        k_bshd,
        v_bshd,
        batch_dim_idx=0,
        attn_mask=None,  # mask is delivered via the padding-mask context
        dropout_p=dropout,
        is_causal=is_causal,
        scale=scaling,
    )

    return attn_output, None

def sdpa_wrapper(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True, scale=None):
    """Apply SDPA to tensors laid out as [batch, seq, heads, head_dim].

    If a padding mask has been stashed via ``set_padding_mask``, a combined
    causal + padding boolean mask of shape [B, 1, S, S] is built and passed
    explicitly; otherwise causality is delegated to SDPA's ``is_causal``.
    ``attn_mask`` from the caller is not consulted — the mask travels
    through the module-level context instead.
    Returns the attention output in the same [b, s, n, h] layout.
    """
    # [b, s, n, h] -> [b, n, s, h], the layout SDPA operates on.
    q = query.transpose(1, 2).contiguous()
    k = key.transpose(1, 2).contiguous()
    v = value.transpose(1, 2).contiguous()

    combined_mask = None
    padding_mask = get_padding_mask()  # [B, S] or None

    if padding_mask is not None:
        seq_len = padding_mask.size(1)
        dev = padding_mask.device

        # Lower-triangular causal mask [S, S], True where attention is allowed.
        causal = torch.tril(torch.ones(seq_len, seq_len, device=dev, dtype=torch.bool))
        # Nonzero padding entries mark valid key positions: [B, 1, S].
        valid_keys = padding_mask.ne(0).unsqueeze(1)
        # Broadcast-AND to [B, S, S], then insert the head axis -> [B, 1, S, S].
        combined_mask = (causal.unsqueeze(0) & valid_keys).unsqueeze(1)

    # When we built an explicit mask, it already encodes causality, so
    # is_causal must be disabled (SDPA forbids supplying both).
    output = torch.nn.functional.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=combined_mask,
        dropout_p=dropout_p,
        is_causal=combined_mask is None and is_causal,
        scale=scale,
        enable_gqa=False,
    )

    # [b, n, s, h] -> [b, s, n, h] for the output all-to-all.
    return output.transpose(1, 2).contiguous()
Loading