diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
index 3ecf8db7ef06..63490a1f40fd 100644
--- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -480,6 +480,15 @@ def __init__(self, config: GptOssConfig):
         self.norm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = GptOssRotaryEmbedding(config=config)
         self.gradient_checkpointing = False
+        if (
+            "flash" in config._attn_implementation
+            and config._attn_implementation != "kernels-community/vllm-flash-attn3"
+        ):
+            raise ValueError(
+                f"GPT-OSS model does not support the specified "
+                f"flash attention implementation: {config._attn_implementation}. "
+                f"Only kernels-community/vllm-flash-attn3 is supported."
+            )
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -499,7 +508,15 @@ def forward(
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
+        if (
+            "flash" in self.config._attn_implementation
+            and self.config._attn_implementation != "kernels-community/vllm-flash-attn3"
+        ):
+            raise ValueError(
+                f"GPT-OSS model does not support the specified "
+                f"flash attention implementation: {self.config._attn_implementation}. "
+                f"Only kernels-community/vllm-flash-attn3 is supported."
+            )
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
 
diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py
index 26587ba4e2a8..06c12ef329c3 100644
--- a/src/transformers/models/gpt_oss/modular_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py
@@ -381,6 +381,20 @@ def _init_weights(self, module):
 class GptOssModel(MixtralModel):
     _no_split_modules = ["GptOssDecoderLayer"]
 
+    def __init__(self, config: GptOssConfig):
+        super().__init__(config)
+
+        if (
+            "flash" in config._attn_implementation
+            and config._attn_implementation != "kernels-community/vllm-flash-attn3"
+        ):
+            raise ValueError(
+                f"GPT-OSS model does not support the specified "
+                f"flash attention implementation: {config._attn_implementation}. "
+                f"Only 'kernels-community/vllm-flash-attn3' is supported."
+            )
+
+
     @check_model_inputs
     @auto_docstring
     def forward(
@@ -396,6 +410,17 @@ def forward(
     ) -> MoeModelOutputWithPast:
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        if (
+            "flash" in self.config._attn_implementation
+            and self.config._attn_implementation != "kernels-community/vllm-flash-attn3"
+        ):
+            raise ValueError(
+                f"GPT-OSS model does not support the specified "
+                f"flash attention implementation: {self.config._attn_implementation}. "
+                f"Only 'kernels-community/vllm-flash-attn3' is supported."
+            )
+
         if use_cache and past_key_values is None:
             past_key_values = DynamicCache(config=self.config)
 
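
For reference, a minimal standalone sketch of the condition the patch introduces; the helper name `validate_gpt_oss_attn_implementation` is hypothetical and exists only for illustration (in the patch the check lives inline in `GptOssModel.__init__` and `forward`). The intent shown by the diff is that any `attn_implementation` containing "flash" other than `kernels-community/vllm-flash-attn3` is rejected at model construction and at forward time, while non-flash backends such as `eager` or `sdpa` are left untouched.

```python
# Hypothetical standalone restatement of the guard added in the diff above.
def validate_gpt_oss_attn_implementation(attn_implementation: str) -> None:
    if (
        "flash" in attn_implementation
        and attn_implementation != "kernels-community/vllm-flash-attn3"
    ):
        raise ValueError(
            f"GPT-OSS model does not support the specified "
            f"flash attention implementation: {attn_implementation}. "
            f"Only kernels-community/vllm-flash-attn3 is supported."
        )


# Non-flash backends and the vLLM FA3 kernel pass the check.
for impl in ("eager", "sdpa", "kernels-community/vllm-flash-attn3"):
    validate_gpt_oss_attn_implementation(impl)

# Any other implementation containing "flash" is rejected.
for impl in ("flash_attention_2", "flash_attention_3"):
    try:
        validate_gpt_oss_attn_implementation(impl)
    except ValueError as err:
        print(err)
```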