From b9e79d1fa77f0db34725954321ee2e77c1eb12f0 Mon Sep 17 00:00:00 2001
From: pushkar-hue
Date: Mon, 8 Dec 2025 22:06:04 +0530
Subject: [PATCH 1/3] gpt-oss not working with flash attention

---
 src/transformers/models/gpt_oss/modular_gpt_oss.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py
index 9432ac28b0fb..29f94d311db8 100644
--- a/src/transformers/models/gpt_oss/modular_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -383,6 +383,17 @@ def _init_weights(self, module):
 class GptOssModel(MixtralModel):
     _no_split_modules = ["GptOssDecoderLayer"]
 
+    def __init__(self, config: GptOssConfig):
+        super().__init__(config)
+
+        if config._attn_implementation in ["flash_attention_2", "flash_attention_3"]:
+            raise ValueError(
+                f"GPT-OSS models do not support {config._attn_implementation} because they utilize "
+                "attention sinks, which are not currently supported by the standard Flash Attention kernels. "
+                "Please use 'eager' implementation (attn_implementation='eager') or a custom kernel if available."
+            )
+    
+
     @check_model_inputs
     @auto_docstring
     def forward(

From 63bf614a8e6e6cae33ab48a8822a782c2c592000 Mon Sep 17 00:00:00 2001
From: pushkar-hue
Date: Tue, 9 Dec 2025 19:18:13 +0530
Subject: [PATCH 2/3] changed conditions to allow vllm kernel

---
 .../models/gpt_oss/modular_gpt_oss.py | 24 +++++++++++++++----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py
index 29f94d311db8..2b71539c3dab 100644
--- a/src/transformers/models/gpt_oss/modular_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -386,14 +386,17 @@ class GptOssModel(MixtralModel):
     def __init__(self, config: GptOssConfig):
         super().__init__(config)
 
-        if config._attn_implementation in ["flash_attention_2", "flash_attention_3"]:
+        if (
+            "flash" in config._attn_implementation
+            and config._attn_implementation != "kernels-community/vllm-flash-attn3"
+        ):
             raise ValueError(
-                f"GPT-OSS models do not support {config._attn_implementation} because they utilize "
-                "attention sinks, which are not currently supported by the standard Flash Attention kernels. "
-                "Please use 'eager' implementation (attn_implementation='eager') or a custom kernel if available."
+                f"GPT-OSS model does not support the specified "
+                f"flash attention implementation: {config._attn_implementation}. "
+                f"Only '{vllm_fa3_kernel}' is supported."
             )
-    
+
 
     @check_model_inputs
     @auto_docstring
     def forward(
@@ -409,6 +412,17 @@ def forward(
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if (
+            "flash" in config._attn_implementation
+            and config._attn_implementation != "kernels-community/vllm-flash-attn3"
+        ):
+            raise ValueError(
+                f"GPT-OSS model does not support the specified "
+                f"flash attention implementation: {config._attn_implementation}. "
+                f"Only '{vllm_fa3_kernel}' is supported."
+            )
+
 
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)

From c84cc0e1ad0bac16a219c8a876b70a690af5d616 Mon Sep 17 00:00:00 2001
From: pushkar-hue
Date: Sat, 13 Dec 2025 22:25:36 +0530
Subject: [PATCH 3/3] changed conditions to allow vllm kernel and modeling file
 generated

---
 .../models/gpt_oss/modeling_gpt_oss.py | 30 ++++++++++++++-----
 .../models/gpt_oss/modular_gpt_oss.py  | 10 +++----
 2 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
index fac4f8b5680a..63490a1f40fd 100644
--- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -28,6 +28,7 @@
 from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
+from ...integrations import use_kernelized_func
 from ...integrations.hub_kernels import use_kernel_forward_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_layers import (
@@ -40,7 +41,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_gpt_oss import GptOssConfig
 
 
@@ -235,7 +236,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = freqs
             cos = emb.cos() * self.attention_scaling
@@ -301,12 +302,13 @@ def eager_attention_forward(
     combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
     probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype)
     scores = probs[..., :-1]  # we drop the sink here
-    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training)
+    attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training).to(value_states.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
     return attn_output, attn_weights
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class GptOssAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -332,7 +334,6 @@ def __init__(self, config: GptOssConfig, layer_idx: int):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
         self.sinks = nn.Parameter(torch.empty(config.num_attention_heads))
 
@@ -343,7 +344,6 @@ def forward(
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -373,7 +373,6 @@ def forward(
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
             sliding_window=self.sliding_window,
-            position_ids=position_ids,
             s_aux=self.sinks,  # diff with Llama
             **kwargs,
         )
@@ -481,6 +480,15 @@ def __init__(self, config: GptOssConfig):
         self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = GptOssRotaryEmbedding(config=config)
         self.gradient_checkpointing = False
+        if (
+            "flash" in config._attn_implementation
+            and config._attn_implementation != "kernels-community/vllm-flash-attn3"
+        ):
+            raise ValueError(
+                f"GPT-OSS model does not support the specified "
+                f"flash attention implementation: {config._attn_implementation}. "
+                f"Only kernels-community/vllm-flash-attn3 is supported."
+            )
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -500,7 +508,15 @@ def forward(
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
+        if (
+            "flash" in self.config._attn_implementation
+            and self.config._attn_implementation != "kernels-community/vllm-flash-attn3"
+        ):
+            raise ValueError(
+                f"GPT-OSS model does not support the specified "
+                f"flash attention implementation: {self.config._attn_implementation}. "
+                f"Only kernels-community/vllm-flash-attn3 is supported."
+            )
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
 
diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py
index 2b71539c3dab..64e58f1452c8 100644
--- a/src/transformers/models/gpt_oss/modular_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -393,7 +393,7 @@ def __init__(self, config: GptOssConfig):
             raise ValueError(
                 f"GPT-OSS model does not support the specified "
                 f"flash attention implementation: {config._attn_implementation}. "
-                f"Only '{vllm_fa3_kernel}' is supported."
+                f"Only 'kernels-community/vllm-flash-attn3' is supported."
             )
 
 
@@ -414,13 +414,13 @@ def forward(
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
         if (
-            "flash" in config._attn_implementation
-            and config._attn_implementation != "kernels-community/vllm-flash-attn3"
+            "flash" in self.config._attn_implementation
+            and self.config._attn_implementation != "kernels-community/vllm-flash-attn3"
         ):
             raise ValueError(
                 f"GPT-OSS model does not support the specified "
-                f"flash attention implementation: {config._attn_implementation}. "
-                f"Only '{vllm_fa3_kernel}' is supported."
+                f"flash attention implementation: {self.config._attn_implementation}. "
+                f"Only 'kernels-community/vllm-flash-attn3' is supported."
             )
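
For reference, a minimal usage sketch of the behaviour this patch series ends up with. The sketch is not part of the patches: the checkpoint id, the installed `kernels` package, and a flash-attn capable environment are assumptions. Among flash attention variants, only the Hub kernel `kernels-community/vllm-flash-attn3` passes the new guard in GptOssModel; any other flash implementation raises the ValueError added above, while non-flash implementations remain available.

# Usage sketch, assuming a transformers build containing these patches and the
# `kernels` package for loading Hub kernels. "openai/gpt-oss-20b" is an assumed
# checkpoint id; it is not named anywhere in the patches.
from transformers import AutoModelForCausalLM

checkpoint = "openai/gpt-oss-20b"  # assumed checkpoint id

# Allowed: the vLLM flash-attention 3 kernel from the Hub, the only flash
# implementation the new guard accepts (it can handle the attention sinks).
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    attn_implementation="kernels-community/vllm-flash-attn3",
)

# Rejected: with flash-attn installed, requesting any other flash implementation
# now fails fast in GptOssModel.__init__ with the error message added above.
try:
    AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation="flash_attention_2")
except ValueError as err:
    print(err)  # GPT-OSS model does not support the specified flash attention implementation: ...

# Non-flash implementations (e.g. "eager", "sdpa") are untouched by the new guard.
model_eager = AutoModelForCausalLM.from_pretrained(checkpoint, attn_implementation="eager")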