
Commit 3f3cae7

Cyrilvallez, Wauplin, and ArthurZucker authored
🚨🚨 [saving] Default to 50GB shards, and remove non-safe serialization (#42734)
* switch
* remove now useless save_function
* a bit more involved than i thought
* all converters
* fix
* pretty print
* fix
* trainer
* update musicgen.md docs
* marc comments
* doc and last missed instances
* CI

---------

Co-authored-by: Wauplin <[email protected]>
Co-authored-by: Arthur <[email protected]>
1 parent 73a13f8 commit 3f3cae7
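In practice, the change boils down to the following minimal sketch (the model id and output directory are illustrative, not from the commit):

```py
# a minimal sketch of the new saving behavior (illustrative model id and paths)
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

# weights are now always written as safetensors, sharded at 50GB by default;
# the `safe_serialization` and `save_function` arguments are gone
model.save_pretrained("my-checkpoint")

# the previous 5GB shard size can still be requested explicitly
model.save_pretrained("my-checkpoint", max_shard_size="5GB")
```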

File tree: 81 files changed, +404 −683 lines changed


docs/source/en/model_doc/musicgen.md

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ This model was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-ga
 
 ```bash
 python src/transformers/models/musicgen/convert_musicgen_transformers.py \
-    --checkpoint small --pytorch_dump_folder /output/path --safe_serialization
+    --checkpoint small --pytorch_dump_folder /output/path
 ```
 
 ## Generation
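Since the converter now always writes safetensors, loading the dump folder back is unchanged; a quick sanity check, assuming the paths from the command above:

```py
# a quick check that the converted dump loads as usual (path is illustrative)
from transformers import MusicgenForConditionalGeneration

model = MusicgenForConditionalGeneration.from_pretrained("/output/path")
```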

docs/source/en/quantization/torchao.md

Lines changed: 14 additions & 9 deletions
@@ -639,30 +639,35 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
 
 ## Serialization
 
-torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchao.
-
-To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).
+Saving the quantized model with `save_pretrained` (in [safetensors](https://huggingface.co/docs/safetensors/en/index) format) is only supported for torchao >= v0.15. For any version below, it is only possible to manually save as unsafe `.bin` checkpoints with [torch.save](https://docs.pytorch.org/docs/stable/generated/torch.save.html).
 
 <hfoptions id="serialization-examples">
 <hfoption id="save-locally">
 
 ```py
-# don't serialize model with Safetensors
+# torchao >= 0.15
 output_dir = "llama3-8b-int4wo-128"
-quantized_model.save_pretrained("llama3-8b-int4wo-128", safe_serialization=False)
+quantized_model.save_pretrained("llama3-8b-int4wo-128")
 ```
 
 </hfoption>
 <hfoption id="push-to-huggingface-hub">
 
 ```py
-# don't serialize model with Safetensors
+# torchao >= 0.15
 USER_ID = "your_huggingface_user_id"
 REPO_ID = "llama3-8b-int4wo-128"
-quantized_model.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128", safe_serialization=False)
+quantized_model.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128")
 tokenizer.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128")
 ```
 
+
+```py
+# torchao < 0.15 -> unsafe serialization
+filename = "llama3-8b-int4wo-128/pytorch_model.bin"
+torch.save(quantized_model.state_dict(), filename)
+```
+
 </hfoption>
 </hfoptions>
 
@@ -687,7 +692,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(
 )
 # save the quantized model
 output_dir = "llama-3.1-8b-torchao-int8"
-quantized_model.save_pretrained(output_dir, safe_serialization=False)
+quantized_model.save_pretrained(output_dir)
 
 # reload the quantized model
 reloaded_model = AutoModelForCausalLM.from_pretrained(
@@ -724,7 +729,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(
 )
 # save the quantized model
 output_dir = "llama-3.1-8b-torchao-int4-cpu"
-quantized_model.save_pretrained(output_dir, safe_serialization=False)
+quantized_model.save_pretrained(output_dir)
 
 # reload the quantized model
 reloaded_model = AutoModelForCausalLM.from_pretrained(
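Taken together, the updated docs imply gating the save path on the installed torchao version; a minimal sketch, assuming `quantized_model` from the examples above and the `packaging` package:

```py
# a minimal sketch: choose the save path based on the installed torchao version
# (assumes `quantized_model` was created as in the examples in this file)
import torch
import torchao
from packaging import version

if version.parse(torchao.__version__) >= version.parse("0.15.0"):
    quantized_model.save_pretrained("llama3-8b-int4wo-128")  # safetensors
else:
    # older torchao: fall back to an unsafe pickle-based checkpoint
    torch.save(quantized_model.state_dict(), "llama3-8b-int4wo-128/pytorch_model.bin")
```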

examples/3D_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -340,7 +340,7 @@ def collate_fn(batch):
     else:
         # Fallback to regular save for non-distributed case
         save_dir = "test_model_nondist"
-        model.save_pretrained(save_dir, safe_serialization=False)
+        model.save_pretrained(save_dir)
         tokenizer.save_pretrained(save_dir)  # Save tokenizer too
         logger.info(f"Saved model to {save_dir}")

examples/pytorch/3d_parallel_checks.py

Lines changed: 1 addition & 1 deletion
@@ -458,7 +458,7 @@ def collate_fn(batch):
     else:
         # Fallback to regular save for non-distributed case
         save_dir = "test_model_nondist"
-        model.save_pretrained(save_dir, safe_serialization=False)
+        model.save_pretrained(save_dir)
         tokenizer.save_pretrained(save_dir)  # Save tokenizer too
         logger.info(f"Saved model to {save_dir}")

examples/quantization/custom_quantization_int8_example.py

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ def _process_model_after_weight_loading(self, model, **kwargs):
         """
         return True
 
-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self):
         return True
 
     @property
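Under the new contract, `is_serializable` takes no format flag, since safetensors is the only output format left; a minimal sketch of a custom quantizer stub (the class name is illustrative):

```py
# a minimal sketch of the simplified hook (illustrative class name)
from transformers.quantizers import HfQuantizer

class MyInt8Quantizer(HfQuantizer):
    requires_calibration = False

    def is_serializable(self):
        # no safe_serialization argument anymore: safetensors is the only format
        return True
```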

src/transformers/modeling_utils.py

Lines changed: 77 additions & 97 deletions
@@ -93,7 +93,6 @@
 from .safetensors_conversion import auto_conversion
 from .utils import (
     ADAPTER_SAFE_WEIGHTS_NAME,
-    ADAPTER_WEIGHTS_NAME,
     DUMMY_INPUTS,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -551,8 +550,7 @@ def _get_resolved_checkpoint_files(
             raise OSError(
                 f"{pretrained_model_name_or_path} does not appear to have a file named"
                 f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
-                "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
-                "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
+                "and thus cannot be loaded with `safetensors`. Please do not set `use_safetensors=True`."
             )
         else:
             # This repo has no safetensors file of any kind, we switch to PyTorch.
@@ -3009,10 +3007,8 @@ def save_pretrained(
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
         state_dict: Optional[dict] = None,
-        save_function: Callable = torch.save,
         push_to_hub: bool = False,
-        max_shard_size: Union[int, str] = "5GB",
-        safe_serialization: bool = True,
+        max_shard_size: Union[int, str] = "50GB",
         variant: Optional[str] = None,
         token: Optional[Union[str, bool]] = None,
         save_peft_format: bool = True,
@@ -3034,18 +3030,13 @@ def save_pretrained(
                 The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
                 save parts of the model or if special precautions need to be taken when recovering the state dictionary
                 of a model (like when using model parallelism).
-            save_function (`Callable`):
-                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
-                need to replace `torch.save` by another method.
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
-            max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
+            max_shard_size (`int` or `str`, *optional*, defaults to `"50GB"`):
                 The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
                 lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
-                We default it to 5GB in order for models to be able to run easily on free-tier google colab instances
-                without CPU OOM issues.
 
                 <Tip warning={true}>
 
@@ -3054,10 +3045,8 @@ def save_pretrained(
 
                 </Tip>
 
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
             variant (`str`, *optional*):
-                If specified, weights are saved in the format pytorch_model.<variant>.bin.
+                If specified, weights are saved in the format model.<variant>.safetensors.
             token (`str` or `bool`, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                 the token generated when running `hf auth login` (stored in `~/.huggingface`).
@@ -3079,9 +3068,7 @@ def save_pretrained(
 
         hf_quantizer = getattr(self, "hf_quantizer", None)
         quantization_serializable = (
-            hf_quantizer is not None
-            and isinstance(hf_quantizer, HfQuantizer)
-            and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
+            hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable()
         )
 
         if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
@@ -3117,7 +3104,7 @@ def save_pretrained(
 
         metadata = {}
         if hf_quantizer is not None:
-            state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
+            state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self)
             metadata["format"] = "pt"
 
         # Only save the model itself if we are using distributed training
@@ -3209,86 +3196,83 @@ def save_pretrained(
         if self._tp_size is not None:
             state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
 
-        if safe_serialization:
-            # TODO: fix safe_serialization for tied weights
-            # Safetensors does not allow tensor aliasing.
-            # We're going to remove aliases before saving
-            ptrs = collections.defaultdict(list)
-            for name, tensor in state_dict.items():
-                if not isinstance(tensor, torch.Tensor):
-                    # Sometimes in the state_dict we have non-tensor objects.
-                    # e.g. in bitsandbytes we have some `str` objects in the state_dict
-                    # In the non-tensor case, fall back to the pointer of the object itself
-                    ptrs[id(tensor)].append(name)
-
-                elif tensor.device.type == "meta":
-                    # In offloaded cases, there may be meta tensors in the state_dict.
-                    # For these cases, key by the pointer of the original tensor object
-                    # (state_dict tensors are detached and therefore no longer shared)
-                    tensor = self.get_parameter(name)
-                    ptrs[id(tensor)].append(name)
+        # Safetensors does not allow tensor aliasing - we're going to remove aliases before saving
+        ptrs = collections.defaultdict(list)
+        for name, tensor in state_dict.items():
+            if not isinstance(tensor, torch.Tensor):
+                # Sometimes in the state_dict we have non-tensor objects.
+                # e.g. in bitsandbytes we have some `str` objects in the state_dict
+                # In the non-tensor case, fall back to the pointer of the object itself
+                ptrs[id(tensor)].append(name)
+
+            elif tensor.device.type == "meta":
+                # In offloaded cases, there may be meta tensors in the state_dict.
+                # For these cases, key by the pointer of the original tensor object
+                # (state_dict tensors are detached and therefore no longer shared)
+                tensor = self.get_parameter(name)
+                ptrs[id(tensor)].append(name)
 
-            else:
-                ptrs[id_tensor_storage(tensor)].append(name)
-
-            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-
-            # Recursively descend to find tied weight keys
-            _tied_weights_keys = set(_get_tied_weight_keys(self))
-            error_names = []
-            to_delete_names = set()
-            for names in shared_ptrs.values():
-                # Removing the keys which are declared as known duplicates on
-                # load. This allows to make sure the name which is kept is consistent.
-                if _tied_weights_keys is not None:
-                    found = 0
-                    for name in sorted(names):
-                        matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
-                        if matches_pattern and name in state_dict:
-                            found += 1
-                            if found < len(names):
-                                to_delete_names.add(name)
-            # We are entering a place where the weights and the transformers configuration do NOT match.
-            shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
-            # Those are actually tensor sharing but disjoint from each other, we can safely clone them
-            # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
-            for name in disjoint_names:
-                state_dict[name] = state_dict[name].clone()
-
-            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-            # If the link between tensors was done at runtime then `from_pretrained` will not get
-            # the key back leading to random tensor. A proper warning will be shown
-            # during reload (if applicable), but since the file is not necessarily compatible with
-            # the config, better show a proper warning.
-            shared_names, identical_names = _find_identical(shared_names, state_dict)
-            # delete tensors that have identical storage
-            for inames in identical_names:
-                known = inames.intersection(to_delete_names)
-                for name in known:
-                    del state_dict[name]
-                unknown = inames.difference(to_delete_names)
-                if len(unknown) > 1:
-                    error_names.append(unknown)
-
-            if shared_names:
-                error_names.extend(shared_names)
-
-            if len(error_names) > 0:
-                raise RuntimeError(
-                    f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
-                    "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
-                )
+            else:
+                ptrs[id_tensor_storage(tensor)].append(name)
+
+        shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+
+        # Recursively descend to find tied weight keys
+        _tied_weights_keys = set(_get_tied_weight_keys(self))
+        error_names = []
+        to_delete_names = set()
+        for names in shared_ptrs.values():
+            # Removing the keys which are declared as known duplicates on
+            # load. This allows to make sure the name which is kept is consistent.
+            if _tied_weights_keys is not None:
+                found = 0
+                for name in sorted(names):
+                    matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
+                    if matches_pattern and name in state_dict:
+                        found += 1
+                        if found < len(names):
+                            to_delete_names.add(name)
+        # We are entering a place where the weights and the transformers configuration do NOT match.
+        shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+        # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+        # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+        for name in disjoint_names:
+            state_dict[name] = state_dict[name].clone()
+
+        # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+        # If the link between tensors was done at runtime then `from_pretrained` will not get
+        # the key back leading to random tensor. A proper warning will be shown
+        # during reload (if applicable), but since the file is not necessarily compatible with
+        # the config, better show a proper warning.
+        shared_names, identical_names = _find_identical(shared_names, state_dict)
+        # delete tensors that have identical storage
+        for inames in identical_names:
+            known = inames.intersection(to_delete_names)
+            for name in known:
+                del state_dict[name]
+            unknown = inames.difference(to_delete_names)
+            if len(unknown) > 1:
+                error_names.append(unknown)
+
+        if shared_names:
+            error_names.extend(shared_names)
+
+        if len(error_names) > 0:
+            raise RuntimeError(
+                f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
+                "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
+            )
 
         # Revert all renaming and/or weight operations
         if save_original_format:
             state_dict = revert_weight_conversion(self, state_dict)
 
         # Shard the model if it is too big.
         if not _hf_peft_config_loaded:
-            weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+            weights_name = SAFE_WEIGHTS_NAME
             weights_name = _add_variant(weights_name, variant)
         else:
-            weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
+            weights_name = ADAPTER_SAFE_WEIGHTS_NAME
 
         filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
         state_dict_split = split_torch_state_dict_into_shards(
@@ -3357,21 +3341,17 @@ def save_pretrained(
             del shard_state_dict
             gc.collect()
 
-            if safe_serialization:
-                # At some point we will need to deal better with save_function (used for TPU and other distributed
-                # joyfulness), but for now this enough. # TODO: we should def parallelize this we are otherwise just waiting
-                # too much before scheduling the next write when its in a different file
-                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
-            else:
-                save_function(shard, os.path.join(save_directory, shard_file))
+            # TODO: we should def parallelize this we are otherwise just waiting
+            # too much before scheduling the next write when its in a different file
+            safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
 
         del state_dict
 
         if index is None:
             path_to_weights = os.path.join(save_directory, weights_name)
             logger.info(f"Model weights saved in {path_to_weights}")
         else:
-            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME
             save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
             # Save the index as well
             with open(save_index_file, "w", encoding="utf-8") as f:
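Condensed, the rewritten save path always does three things: group tensor names by underlying storage to drop aliases (safetensors forbids tensor aliasing), split the state dict into shards of at most 50GB, and write each shard with safetensors. A standalone sketch of that flow with illustrative helper names; the real code additionally resolves tied weights against `_tied_weights_keys`, handles variants, and writes the sharded index file:

```py
# a standalone sketch of the safetensors-only save path (illustrative names)
import collections
import os

from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

def sketch_save(state_dict, save_directory, max_shard_size="50GB"):
    os.makedirs(save_directory, exist_ok=True)

    # 1) group tensor names by underlying storage, in the spirit of id_tensor_storage;
    #    safetensors refuses aliased tensors, so duplicates must be dropped first
    ptrs = collections.defaultdict(list)
    for name, tensor in state_dict.items():
        storage = tensor.untyped_storage()
        ptrs[(tensor.device, storage.data_ptr(), storage.nbytes())].append(name)
    for names in ptrs.values():
        for duplicate in sorted(names)[1:]:  # keep one name per shared storage
            del state_dict[duplicate]

    # 2) split into shards no larger than max_shard_size (now 50GB by default)
    split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern="model{suffix}.safetensors", max_shard_size=max_shard_size
    )

    # 3) write every shard with safetensors; the pickle-based .bin path is gone
    for shard_file, tensor_names in split.filename_to_tensors.items():
        shard = {name: state_dict[name].contiguous() for name in tensor_names}
        save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
```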
