 from .safetensors_conversion import auto_conversion
 from .utils import (
     ADAPTER_SAFE_WEIGHTS_NAME,
-    ADAPTER_WEIGHTS_NAME,
     DUMMY_INPUTS,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFE_WEIGHTS_NAME,
@@ -551,8 +550,7 @@ def _get_resolved_checkpoint_files(
                 raise OSError(
                     f"{pretrained_model_name_or_path} does not appear to have a file named"
                     f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
-                    "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
-                    "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
+                    "and thus cannot be loaded with `safetensors`. Please do not set `use_safetensors=True`."
                 )
             else:
                 # This repo has no safetensors file of any kind, we switch to PyTorch.
@@ -3009,10 +3007,8 @@ def save_pretrained(
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
         state_dict: Optional[dict] = None,
-        save_function: Callable = torch.save,
         push_to_hub: bool = False,
-        max_shard_size: Union[int, str] = "5GB",
-        safe_serialization: bool = True,
+        max_shard_size: Union[int, str] = "50GB",
         variant: Optional[str] = None,
         token: Optional[Union[str, bool]] = None,
         save_peft_format: bool = True,
@@ -3034,18 +3030,13 @@ def save_pretrained(
                 The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
                 save parts of the model or if special precautions need to be taken when recovering the state dictionary
                 of a model (like when using model parallelism).
-            save_function (`Callable`):
-                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
-                need to replace `torch.save` by another method.
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
-            max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
+            max_shard_size (`int` or `str`, *optional*, defaults to `"50GB"`):
                 The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
                 lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
-                We default it to 5GB in order for models to be able to run easily on free-tier google colab instances
-                without CPU OOM issues.
 
                 <Tip warning={true}>
 
@@ -3054,10 +3045,8 @@ def save_pretrained(
 
                 </Tip>
 
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
             variant (`str`, *optional*):
-                If specified, weights are saved in the format pytorch_model.<variant>.bin.
+                If specified, weights are saved in the format model.<variant>.safetensors.
             token (`str` or `bool`, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                 the token generated when running `hf auth login` (stored in `~/.huggingface`).
@@ -3079,9 +3068,7 @@ def save_pretrained(
 
         hf_quantizer = getattr(self, "hf_quantizer", None)
         quantization_serializable = (
-            hf_quantizer is not None
-            and isinstance(hf_quantizer, HfQuantizer)
-            and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
+            hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable()
         )
 
         if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
@@ -3117,7 +3104,7 @@ def save_pretrained(
 
         metadata = {}
         if hf_quantizer is not None:
-            state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
+            state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self)
             metadata["format"] = "pt"
 
         # Only save the model itself if we are using distributed training
@@ -3209,86 +3196,83 @@ def save_pretrained(
         if self._tp_size is not None:
             state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
 
-        if safe_serialization:
-            # TODO: fix safe_serialization for tied weights
-            # Safetensors does not allow tensor aliasing.
-            # We're going to remove aliases before saving
-            ptrs = collections.defaultdict(list)
-            for name, tensor in state_dict.items():
-                if not isinstance(tensor, torch.Tensor):
-                    # Sometimes in the state_dict we have non-tensor objects.
-                    # e.g. in bitsandbytes we have some `str` objects in the state_dict
-                    # In the non-tensor case, fall back to the pointer of the object itself
-                    ptrs[id(tensor)].append(name)
-
-                elif tensor.device.type == "meta":
-                    # In offloaded cases, there may be meta tensors in the state_dict.
-                    # For these cases, key by the pointer of the original tensor object
-                    # (state_dict tensors are detached and therefore no longer shared)
-                    tensor = self.get_parameter(name)
-                    ptrs[id(tensor)].append(name)
+        # Safetensors does not allow tensor aliasing - we're going to remove aliases before saving
+        ptrs = collections.defaultdict(list)
+        for name, tensor in state_dict.items():
+            if not isinstance(tensor, torch.Tensor):
+                # Sometimes in the state_dict we have non-tensor objects.
+                # e.g. in bitsandbytes we have some `str` objects in the state_dict
+                # In the non-tensor case, fall back to the pointer of the object itself
+                ptrs[id(tensor)].append(name)
+
+            elif tensor.device.type == "meta":
+                # In offloaded cases, there may be meta tensors in the state_dict.
+                # For these cases, key by the pointer of the original tensor object
+                # (state_dict tensors are detached and therefore no longer shared)
+                tensor = self.get_parameter(name)
+                ptrs[id(tensor)].append(name)
 
-            else:
-                ptrs[id_tensor_storage(tensor)].append(name)
-
-            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-
-            # Recursively descend to find tied weight keys
-            _tied_weights_keys = set(_get_tied_weight_keys(self))
-            error_names = []
-            to_delete_names = set()
-            for names in shared_ptrs.values():
-                # Removing the keys which are declared as known duplicates on
-                # load. This allows to make sure the name which is kept is consistent.
-                if _tied_weights_keys is not None:
-                    found = 0
-                    for name in sorted(names):
-                        matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
-                        if matches_pattern and name in state_dict:
-                            found += 1
-                            if found < len(names):
-                                to_delete_names.add(name)
-            # We are entering a place where the weights and the transformers configuration do NOT match.
-            shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
-            # Those are actually tensor sharing but disjoint from each other, we can safely clone them
-            # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
-            for name in disjoint_names:
-                state_dict[name] = state_dict[name].clone()
-
-            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
-            # If the link between tensors was done at runtime then `from_pretrained` will not get
-            # the key back leading to random tensor. A proper warning will be shown
-            # during reload (if applicable), but since the file is not necessarily compatible with
-            # the config, better show a proper warning.
-            shared_names, identical_names = _find_identical(shared_names, state_dict)
-            # delete tensors that have identical storage
-            for inames in identical_names:
-                known = inames.intersection(to_delete_names)
-                for name in known:
-                    del state_dict[name]
-                unknown = inames.difference(to_delete_names)
-                if len(unknown) > 1:
-                    error_names.append(unknown)
-
-            if shared_names:
-                error_names.extend(shared_names)
-
-            if len(error_names) > 0:
-                raise RuntimeError(
-                    f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
-                    "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
-                )
+        else:
+            ptrs[id_tensor_storage(tensor)].append(name)
+
+        shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+
+        # Recursively descend to find tied weight keys
+        _tied_weights_keys = set(_get_tied_weight_keys(self))
+        error_names = []
+        to_delete_names = set()
+        for names in shared_ptrs.values():
+            # Removing the keys which are declared as known duplicates on
+            # load. This allows to make sure the name which is kept is consistent.
+            if _tied_weights_keys is not None:
+                found = 0
+                for name in sorted(names):
+                    matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
+                    if matches_pattern and name in state_dict:
+                        found += 1
+                        if found < len(names):
+                            to_delete_names.add(name)
+        # We are entering a place where the weights and the transformers configuration do NOT match.
+        shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+        # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+        # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+        for name in disjoint_names:
+            state_dict[name] = state_dict[name].clone()
+
+        # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+        # If the link between tensors was done at runtime then `from_pretrained` will not get
+        # the key back leading to random tensor. A proper warning will be shown
+        # during reload (if applicable), but since the file is not necessarily compatible with
+        # the config, better show a proper warning.
+        shared_names, identical_names = _find_identical(shared_names, state_dict)
+        # delete tensors that have identical storage
+        for inames in identical_names:
+            known = inames.intersection(to_delete_names)
+            for name in known:
+                del state_dict[name]
+            unknown = inames.difference(to_delete_names)
+            if len(unknown) > 1:
+                error_names.append(unknown)
+
+        if shared_names:
+            error_names.extend(shared_names)
+
+        if len(error_names) > 0:
+            raise RuntimeError(
+                f"The weights trying to be saved contained shared tensors {error_names} which are not properly defined. We found `_tied_weights_keys` to be: {_tied_weights_keys}.\n"
+                "This can also just mean that the module's tied weight keys are wrong vs the actual tied weights in the model.",
+            )
 
         # Revert all renaming and/or weight operations
         if save_original_format:
             state_dict = revert_weight_conversion(self, state_dict)
 
         # Shard the model if it is too big.
         if not _hf_peft_config_loaded:
-            weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+            weights_name = SAFE_WEIGHTS_NAME
             weights_name = _add_variant(weights_name, variant)
         else:
-            weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
+            weights_name = ADAPTER_SAFE_WEIGHTS_NAME
 
         filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
         state_dict_split = split_torch_state_dict_into_shards(
@@ -3357,21 +3341,17 @@ def save_pretrained(
                 del shard_state_dict
                 gc.collect()
 
-            if safe_serialization:
-                # At some point we will need to deal better with save_function (used for TPU and other distributed
-                # joyfulness), but for now this enough. # TODO: we should def parallelize this we are otherwise just waiting
-                # too much before scheduling the next write when its in a different file
-                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
-            else:
-                save_function(shard, os.path.join(save_directory, shard_file))
+            # TODO: we should def parallelize this we are otherwise just waiting
+            # too much before scheduling the next write when its in a different file
+            safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
 
         del state_dict
 
         if index is None:
             path_to_weights = os.path.join(save_directory, weights_name)
             logger.info(f"Model weights saved in {path_to_weights}")
         else:
-            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME
             save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
             # Save the index as well
             with open(save_index_file, "w", encoding="utf-8") as f:
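
Taken together, these hunks make `save_pretrained` safetensors-only: the `save_function` and `safe_serialization` arguments are removed, the default `max_shard_size` moves from "5GB" to "50GB", and weight files always use the `model*.safetensors` / adapter safetensors names. Below is a minimal usage sketch of the resulting call; the model id ("gpt2"), output directory ("./my-model"), and variant ("fp16") are illustrative placeholders, not part of this diff.

# Hedged sketch of calling save_pretrained after this change.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Weights are now always written as safetensors. With the new "50GB" default most
# models land in a single model.safetensors file; pass a smaller max_shard_size to
# force sharding into model-00001-of-0000N.safetensors files plus an index.
model.save_pretrained("./my-model", max_shard_size="2GB")

# With a variant, files are named model.<variant>.safetensors instead of the old
# pytorch_model.<variant>.bin.
model.save_pretrained("./my-model", variant="fp16")

# The removed keyword arguments now fail at call time:
# model.save_pretrained("./my-model", safe_serialization=False)  # TypeError

The larger default shard size presumably tracks the Hub's per-file size guidance rather than the old free-tier Colab RAM constraint; the diff itself does not state the motivation.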