Merged · Changes from 3 commits
src/transformers/models/afmoe/modeling_afmoe.py (4 changes: 2 additions & 2 deletions)

@@ -37,7 +37,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_afmoe import AfmoeConfig

@@ -97,7 +97,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
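Every hunk in this PR follows the same pattern: the RoPE forward passes (and one attention upcast in decision_transformer, below) previously entered `torch.autocast(..., enabled=False)` to force float32, and now go through a `maybe_autocast` helper imported from `...utils.generic`. The helper's implementation is not part of this diff; a minimal sketch of what such a wrapper might look like, assuming it simply degrades to a no-op context when `torch.autocast` rejects the device type:

```python
# NOTE: hypothetical sketch only; the real maybe_autocast lives in
# transformers/utils/generic.py and its body is not shown in this diff.
from contextlib import nullcontext

import torch


def maybe_autocast(device_type=None, **kwargs):
    """Enter torch.autocast when the device type supports it, else a no-op."""
    try:
        return torch.autocast(device_type=device_type, **kwargs)
    except RuntimeError:
        # e.g. a device type torch.autocast does not recognize
        return nullcontext()
```

Either call style seen in the hunks would work with such a signature: `maybe_autocast(device_type=device_type, enabled=False)` and the positional `maybe_autocast(query.device.type, enabled=False)`.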
src/transformers/models/apertus/modeling_apertus.py (4 changes: 2 additions & 2 deletions)

@@ -36,7 +36,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_apertus import ApertusConfig

@@ -131,7 +131,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/arcee/modeling_arcee.py (4 changes: 2 additions & 2 deletions)

@@ -43,7 +43,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_arcee import ArceeConfig

@@ -138,7 +138,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/aria/modeling_aria.py (4 changes: 2 additions & 2 deletions)

@@ -38,7 +38,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ..auto import AutoModel
 from .configuration_aria import AriaConfig, AriaTextConfig

@@ -675,7 +675,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/bamba/modeling_bamba.py (3 changes: 2 additions & 1 deletion)

@@ -43,6 +43,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import maybe_autocast
 from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_bamba import BambaConfig

@@ -250,7 +251,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/bitnet/modeling_bitnet.py (4 changes: 2 additions & 2 deletions)

@@ -36,7 +36,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_bitnet import BitNetConfig

@@ -326,7 +326,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/blt/modeling_blt.py (4 changes: 2 additions & 2 deletions)

@@ -38,7 +38,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_blt import (
     BltConfig,
     BltGlobalTransformerConfig,

@@ -141,7 +141,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
src/transformers/models/blt/modular_blt.py (4 changes: 2 additions & 2 deletions)

@@ -29,7 +29,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, logging
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from ..cohere2.modeling_cohere2 import rotate_half  # noqa: F401
 from ..llama.modeling_llama import LlamaRotaryEmbedding
 from ..mllama.modeling_mllama import (

@@ -277,7 +277,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
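The BLT hunks above (and the Cohere hunks below) carry the comment "diff from Llama: we interleave() instead of cat()". Llama-style RoPE concatenates the half-dimension frequency block with itself, pairing dimension i with dimension i + d/2, while these models repeat each frequency twice so that adjacent dimensions (2i, 2i+1) form a rotation pair. A toy comparison of the two layouts (values are illustrative only):

```python
import torch

freqs = torch.tensor([[0.1, 0.2, 0.3]])  # shape (seq, dim/2)

# Llama-style: pair dimension i with dimension i + dim/2
cat_layout = torch.cat((freqs, freqs), dim=-1)
# tensor([[0.1, 0.2, 0.3, 0.1, 0.2, 0.3]])

# BLT/Cohere-style: pair adjacent dimensions (2i, 2i+1)
interleaved_layout = torch.repeat_interleave(freqs, 2, dim=-1)
# tensor([[0.1, 0.1, 0.2, 0.2, 0.3, 0.3]])
print(cat_layout, interleaved_layout, sep="\n")
```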
src/transformers/models/chameleon/modeling_chameleon.py (3 changes: 2 additions & 1 deletion)

@@ -38,6 +38,7 @@
     can_return_tuple,
     logging,
 )
+from ...utils.generic import maybe_autocast
 from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig

@@ -122,7 +123,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/cohere/modeling_cohere.py (4 changes: 2 additions & 2 deletions)

@@ -45,7 +45,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_cohere import CohereConfig

@@ -122,7 +122,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
src/transformers/models/cohere/modular_cohere.py (3 changes: 2 additions & 1 deletion)

@@ -36,6 +36,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, logging
+from ...utils.generic import maybe_autocast
 from ..llama.modeling_llama import (
     LlamaAttention,
     LlamaForCausalLM,

@@ -75,7 +76,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
src/transformers/models/cohere2/modeling_cohere2.py (4 changes: 2 additions & 2 deletions)

@@ -36,7 +36,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_cohere2 import Cohere2Config

@@ -96,7 +96,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
src/transformers/models/cohere2/modular_cohere2.py (3 changes: 2 additions & 1 deletion)

@@ -30,6 +30,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, logging
+from ...utils.generic import maybe_autocast
 from ..cohere.modeling_cohere import (
     CohereAttention,
     CohereDecoderLayer,

@@ -222,7 +223,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
src/transformers/models/csm/modeling_csm.py (3 changes: 2 additions & 1 deletion)

@@ -40,6 +40,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import maybe_autocast
 from ...utils.import_utils import is_torchdynamo_compiling
 from ..auto import AutoModel
 from .configuration_csm import CsmConfig, CsmDepthDecoderConfig

@@ -174,7 +175,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/cwm/modeling_cwm.py (4 changes: 2 additions & 2 deletions)

@@ -37,7 +37,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_cwm import CwmConfig

@@ -97,7 +97,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/dbrx/modeling_dbrx.py (4 changes: 2 additions & 2 deletions)

@@ -37,7 +37,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dbrx import DbrxConfig

@@ -97,7 +97,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/decision_transformer/modeling_decision_transformer.py

@@ -34,6 +34,7 @@
     auto_docstring,
     logging,
 )
+from ...utils.generic import maybe_autocast
 from .configuration_decision_transformer import DecisionTransformerConfig

@@ -141,7 +142,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.autocast(query.device.type, enabled=False):
+        with maybe_autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
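This hunk differs from the RoPE sites: it guards GPT-2-style upcast attention. Inside an enabled autocast region, matmul-family ops such as `torch.baddbmm` run in half precision even when handed float32 inputs, so the explicit `.float()` upcast only survives if autocast is disabled around the call. A small demonstration of that behavior (assumes a CUDA device with autocast support):

```python
import torch

a = torch.randn(4, 8, 8, device="cuda")
b = torch.randn(4, 8, 8, device="cuda")

with torch.autocast("cuda"):
    # autocast re-casts the float32 inputs down to half precision
    print(torch.bmm(a.float(), b.float()).dtype)  # torch.float16: upcast undone

    with torch.autocast("cuda", enabled=False):
        print(torch.bmm(a.float(), b.float()).dtype)  # torch.float32: upcast kept
```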
src/transformers/models/deepseek_v2/modeling_deepseek_v2.py (4 changes: 2 additions & 2 deletions)

@@ -38,7 +38,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_deepseek_v2 import DeepseekV2Config

@@ -223,7 +223,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
             freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # Convert to complex representation
             freqs_cis = freqs_cis * self.attention_scaling
src/transformers/models/deepseek_v2/modular_deepseek_v2.py (3 changes: 2 additions & 1 deletion)

@@ -25,6 +25,7 @@
 from ...modeling_rope_utils import RopeParameters, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...utils import logging
+from ...utils.generic import maybe_autocast
 from ..llama.configuration_llama import LlamaConfig
 from ..llama.modeling_llama import (
     LlamaDecoderLayer,

@@ -303,7 +304,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
             freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # Convert to complex representation
             freqs_cis = freqs_cis * self.attention_scaling
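The DeepSeek-V2 hunks are the odd ones out among the rotary sites: instead of separate cos/sin tensors they build a complex-valued `freqs_cis` via `torch.polar(abs, angle)`, i.e. abs * (cos(angle) + i*sin(angle)), and apply the rotation as a complex multiply. Forcing float32 matters here too, since `torch.polar` expects floating-point inputs and half-precision angles lose accuracy at large positions. A toy example of the representation (values are hypothetical, not library code):

```python
import math

import torch

# torch.polar builds abs * (cos(angle) + i*sin(angle)) as a complex tensor.
angles = torch.tensor([0.0, math.pi / 2])                 # toy rotation angles
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # ~[1+0j, 0+1j]

# Rotating an (x0, x1) pair is then a single complex multiply.
pairs = torch.view_as_complex(torch.tensor([[1.0, 0.0], [1.0, 0.0]]))
rotated = pairs * freqs_cis
print(torch.view_as_real(rotated))  # ~[[1., 0.], [0., 1.]]: 2nd pair rotated 90 deg
```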
src/transformers/models/deepseek_v3/modeling_deepseek_v3.py (4 changes: 2 additions & 2 deletions)

@@ -29,7 +29,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_deepseek_v3 import DeepseekV3Config

@@ -110,7 +110,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling