30 changes: 23 additions & 7 deletions src/transformers/models/gpt_oss/modeling_gpt_oss.py
Contributor
This is weird, there are a lot of changes like `__init__` -> `_init_` that shouldn't happen. Maybe some ruff version mismatch?

Author
@pushkar-hue Dec 13, 2025
@vasqu Thank you so much for your patience. There was an issue with the modeling file on my side. I have updated the conditions to allow `kernels-community/vllm-flash-attn3`, which I believe is the one you were referring to. I have also added a check inside `forward` to catch any runtime configuration changes. Let me know if I'm missing anything.
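For reference, a minimal usage sketch of the behavior this change enforces (not part of the diff; it assumes the `kernels` package is installed, and the checkpoint id is only illustrative): the `attn_implementation` argument of `from_pretrained` selects the attention backend, and with these checks any flash variant other than `kernels-community/vllm-flash-attn3` is rejected for GPT-OSS.

```python
from transformers import AutoModelForCausalLM

# Accepted: the only flash backend the new checks allow for GPT-OSS.
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",  # illustrative checkpoint id
    attn_implementation="kernels-community/vllm-flash-attn3",
)

# Rejected by the new check at construction time: any other flash implementation.
# AutoModelForCausalLM.from_pretrained(
#     "openai/gpt-oss-20b",
#     attn_implementation="flash_attention_2",
# )  # raises ValueError: GPT-OSS model does not support ... flash_attention_2
```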

@@ -28,6 +28,7 @@
from ... import initialization as init
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernelized_func
from ...integrations.hub_kernels import use_kernel_forward_from_hub
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import (
@@ -40,7 +41,7 @@
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.generic import OutputRecorder, check_model_inputs
from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
from .configuration_gpt_oss import GptOssConfig


@@ -235,7 +236,7 @@ def forward(self, x, position_ids):
position_ids_expanded = position_ids[:, None, :].float()

device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = freqs
cos = emb.cos() * self.attention_scaling
@@ -301,12 +302,13 @@ def eager_attention_forward(
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
scores = probs[..., :-1] # we drop the sink here
attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights


@use_kernelized_func(apply_rotary_pos_emb)
class GptOssAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -332,7 +334,6 @@ def __init__(self, config: GptOssConfig, layer_idx: int):
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
self.rotary_fn = apply_rotary_pos_emb
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))

@@ -343,7 +344,6 @@ def forward(
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
@@ -373,7 +373,6 @@
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=self.sliding_window,
position_ids=position_ids,
s_aux=self.sinks, # diff with Llama
**kwargs,
)
@@ -481,6 +480,15 @@ def __init__(self, config: GptOssConfig):
self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = GptOssRotaryEmbedding(config=config)
self.gradient_checkpointing = False
if (
"flash" in config._attn_implementation
and config._attn_implementation != "kernels-community/vllm-flash-attn3"
):
raise ValueError(
f"GPT-OSS model does not support the specified "
f"flash attention implementation: {config._attn_implementation}. "
f"Only kernels-community/vllm-flash-attn3 is supported."
)

# Initialize weights and apply final processing
self.post_init()
@@ -500,7 +508,15 @@ def forward(
) -> MoeModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if (
"flash" in self.config._attn_implementation
and self.config._attn_implementation != "kernels-community/vllm-flash-attn3"
):
raise ValueError(
f"GPT-OSS model does not support the specified "
f"flash attention implementation: {self.config._attn_implementation}. "
f"Only kernels-community/vllm-flash-attn3 is supported."
)
if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)

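The guard added in both files is the same condition; below is a small standalone sketch (the helper name is hypothetical, not part of the PR) of how it classifies common `attn_implementation` values.

```python
def _is_rejected_flash_impl(attn_implementation: str) -> bool:
    """Mirror of the guard added in this PR: among flash backends, only the
    vLLM FA3 kernel is allowed; non-flash backends pass through untouched."""
    return (
        "flash" in attn_implementation
        and attn_implementation != "kernels-community/vllm-flash-attn3"
    )

assert _is_rejected_flash_impl("flash_attention_2") is True
assert _is_rejected_flash_impl("kernels-community/vllm-flash-attn3") is False
assert _is_rejected_flash_impl("sdpa") is False    # non-flash backends are unaffected
assert _is_rejected_flash_impl("eager") is False
```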
25 changes: 25 additions & 0 deletions src/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -383,6 +383,20 @@ def _init_weights(self, module):
class GptOssModel(MixtralModel):
_no_split_modules = ["GptOssDecoderLayer"]

def __init__(self, config: GptOssConfig):
super().__init__(config)

if (
"flash" in config._attn_implementation
and config._attn_implementation != "kernels-community/vllm-flash-attn3"
):
raise ValueError(
f"GPT-OSS model does not support the specified "
f"flash attention implementation: {config._attn_implementation}. "
f"Only 'kernels-community/vllm-flash-attn3' is supported."
)


@check_model_inputs
@auto_docstring
def forward(
@@ -398,6 +412,17 @@ def forward(
) -> MoeModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if (
"flash" in self.config._attn_implementation
and self.config._attn_implementation != "kernels-community/vllm-flash-attn3"
):
raise ValueError(
f"GPT-OSS model does not support the specified "
f"flash attention implementation: {self.config._attn_implementation}. "
f"Only 'kernels-community/vllm-flash-attn3' is supported."
)


if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)
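The forward-time re-check guards against the configuration being switched after construction. A hedged sketch of the scenario it catches, reusing the `model` loaded in the earlier snippet and assigning the private config attribute exactly as the diff reads it:

```python
import torch

# Assume `model` was loaded with the supported vllm-flash-attn3 backend above.
# Directly overriding the private attribute bypasses the construction-time check,
# which is why the same guard is repeated inside forward.
model.config._attn_implementation = "flash_attention_2"

inputs = torch.tensor([[1, 2, 3]])
try:
    model(input_ids=inputs)
except ValueError as err:
    print(err)  # GPT-OSS model does not support ... flash_attention_2 ...
```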