diff --git a/src/transformers/models/ministral3/convert_ministral3_weights_to_hf.py b/src/transformers/models/ministral3/convert_ministral3_weights_to_hf.py
index 29b267c888e2..baf6787204dd 100644
--- a/src/transformers/models/ministral3/convert_ministral3_weights_to_hf.py
+++ b/src/transformers/models/ministral3/convert_ministral3_weights_to_hf.py
@@ -35,46 +35,48 @@
 
 
 # fmt: off
-STATE_DICT_MAPPING = {
-    # Text model keys
-    r"^output.weight": r"lm_head.weight",
-    r"^norm.weight": r"model.language_model.norm.weight",
-    r"^tok_embeddings.weight": r"model.language_model.embed_tokens.weight",
-    r"^layers.(\d+).attention_norm.weight": r"model.language_model.layers.\1.input_layernorm.weight",
-    r"^layers.(\d+).ffn_norm.weight": r"model.language_model.layers.\1.post_attention_layernorm.weight",
-    r"^layers.(\d+).attention.w(q|k|v|o).weight": r"model.language_model.layers.\1.self_attn.\2_proj.weight",
-    r"^layers.(\d+).feed_forward.w1.weight": r"model.language_model.layers.\1.mlp.gate_proj.weight",
-    r"^layers.(\d+).feed_forward.w2.weight": r"model.language_model.layers.\1.mlp.down_proj.weight",
-    r"^layers.(\d+).feed_forward.w3.weight": r"model.language_model.layers.\1.mlp.up_proj.weight",
-    r"^layers.(\d+).attention.w(q|k|v|o).qscale_act": r"model.language_model.layers.\1.self_attn.\2_proj.activation_scale",
-    r"^layers.(\d+).feed_forward.w1.qscale_act": r"model.language_model.layers.\1.mlp.gate_proj.activation_scale",
-    r"^layers.(\d+).feed_forward.w2.qscale_act": r"model.language_model.layers.\1.mlp.down_proj.activation_scale",
-    r"^layers.(\d+).feed_forward.w3.qscale_act": r"model.language_model.layers.\1.mlp.up_proj.activation_scale",
-    r"^layers.(\d+).attention.w(q|k|v|o).qscale_weight": r"model.language_model.layers.\1.self_attn.\2_proj.weight_scale_inv",
-    r"^layers.(\d+).feed_forward.w1.qscale_weight": r"model.language_model.layers.\1.mlp.gate_proj.weight_scale_inv",
-    r"^layers.(\d+).feed_forward.w2.qscale_weight": r"model.language_model.layers.\1.mlp.down_proj.weight_scale_inv",
-    r"^layers.(\d+).feed_forward.w3.qscale_weight": r"model.language_model.layers.\1.mlp.up_proj.weight_scale_inv",
-
-    # Vision model keys
-    r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"model.vision_tower.transformer.layers.\1.attention_norm.weight",
-    r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"model.vision_tower.transformer.layers.\1.ffn_norm.weight",
-    r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"model.vision_tower.transformer.layers.\1.attention.\2_proj.weight",
-    r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
-    r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
-    r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
-    r"^vision_language_adapter.w_in": r"model.multi_modal_projector.linear_1",
-    r"^vision_language_adapter.w_out": r"model.multi_modal_projector.linear_2",
-    r"^vision_encoder.ln_pre.weight": r"model.vision_tower.ln_pre.weight",
-    r"^vision_encoder.patch_conv.weight": r"model.vision_tower.patch_conv.weight",
-    r"^patch_merger.merging_layer.weight": r"model.multi_modal_projector.patch_merger.merging_layer.weight",
-    r"^pre_mm_projector_norm.weight": r"model.multi_modal_projector.norm.weight",
-}
+def get_sd_mapping(has_vision: bool) -> dict:
+    model_key = "model.language_model" if has_vision else "model"
+    return {
+        # Text model keys
+        r"^output.weight": r"lm_head.weight",
+        r"^norm.weight": rf"{model_key}.norm.weight",
+        r"^tok_embeddings.weight": rf"{model_key}.embed_tokens.weight",
+        r"^layers.(\d+).attention_norm.weight": rf"{model_key}.layers.\1.input_layernorm.weight",
+        r"^layers.(\d+).ffn_norm.weight": rf"{model_key}.layers.\1.post_attention_layernorm.weight",
+        r"^layers.(\d+).attention.w(q|k|v|o).weight": rf"{model_key}.layers.\1.self_attn.\2_proj.weight",
+        r"^layers.(\d+).feed_forward.w1.weight": rf"{model_key}.layers.\1.mlp.gate_proj.weight",
+        r"^layers.(\d+).feed_forward.w2.weight": rf"{model_key}.layers.\1.mlp.down_proj.weight",
+        r"^layers.(\d+).feed_forward.w3.weight": rf"{model_key}.layers.\1.mlp.up_proj.weight",
+        r"^layers.(\d+).attention.w(q|k|v|o).qscale_act": rf"{model_key}.layers.\1.self_attn.\2_proj.activation_scale",
+        r"^layers.(\d+).feed_forward.w1.qscale_act": rf"{model_key}.layers.\1.mlp.gate_proj.activation_scale",
+        r"^layers.(\d+).feed_forward.w2.qscale_act": rf"{model_key}.layers.\1.mlp.down_proj.activation_scale",
+        r"^layers.(\d+).feed_forward.w3.qscale_act": rf"{model_key}.layers.\1.mlp.up_proj.activation_scale",
+        r"^layers.(\d+).attention.w(q|k|v|o).qscale_weight": rf"{model_key}.layers.\1.self_attn.\2_proj.weight_scale_inv",
+        r"^layers.(\d+).feed_forward.w1.qscale_weight": rf"{model_key}.layers.\1.mlp.gate_proj.weight_scale_inv",
+        r"^layers.(\d+).feed_forward.w2.qscale_weight": rf"{model_key}.layers.\1.mlp.down_proj.weight_scale_inv",
+        r"^layers.(\d+).feed_forward.w3.qscale_weight": rf"{model_key}.layers.\1.mlp.up_proj.weight_scale_inv",
+
+        # Vision model keys
+        r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"model.vision_tower.transformer.layers.\1.attention_norm.weight",
+        r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"model.vision_tower.transformer.layers.\1.ffn_norm.weight",
+        r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"model.vision_tower.transformer.layers.\1.attention.\2_proj.weight",
+        r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
+        r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
+        r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
+        r"^vision_language_adapter.w_in": r"model.multi_modal_projector.linear_1",
+        r"^vision_language_adapter.w_out": r"model.multi_modal_projector.linear_2",
+        r"^vision_encoder.ln_pre.weight": r"model.vision_tower.ln_pre.weight",
+        r"^vision_encoder.patch_conv.weight": r"model.vision_tower.patch_conv.weight",
+        r"^patch_merger.merging_layer.weight": r"model.multi_modal_projector.patch_merger.merging_layer.weight",
+        r"^pre_mm_projector_norm.weight": r"model.multi_modal_projector.norm.weight",
+    }
 # fmt: on
 
 
-def map_old_key_to_new(old_key):
+def map_old_key_to_new(old_key, mapping):
     """Map of a key of the original state dict to the equivalent key in HF format"""
-    for pattern, replacement in STATE_DICT_MAPPING.items():
+    for pattern, replacement in mapping.items():
         new_key, n_replace = re.subn(pattern, replacement, old_key)
         # Early exit of the loop
         if n_replace > 0:
@@ -100,11 +102,13 @@ def convert_state_dict(original_state_dict: dict, config: Mistral3Config):
     """Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case)."""
     new_dict = {}
+    is_vision = isinstance(config, Mistral3Config)
+    mapping = get_sd_mapping(is_vision)
 
     for old_key, tensor in original_state_dict.items():
         if "fake_quantizer" in old_key:
             continue
 
-        new_key = map_old_key_to_new(old_key)
+        new_key = map_old_key_to_new(old_key, mapping)
 
         if "vision" in old_key:
             num_attention_heads = config.vision_config.num_attention_heads
@@ -114,10 +118,11 @@ def convert_state_dict(original_state_dict: dict, config: Mistral3Config):
             key_value_dim = head_dim * num_attention_heads
             query_dim = head_dim * num_attention_heads
         else:
-            num_attention_heads = config.text_config.num_attention_heads
-            hidden_size = config.text_config.hidden_size
-            head_dim = config.text_config.head_dim
-            num_key_value_heads = config.text_config.num_key_value_heads
+            text_config = config.text_config if is_vision else config
+            num_attention_heads = text_config.num_attention_heads
+            hidden_size = text_config.hidden_size
+            head_dim = text_config.head_dim
+            num_key_value_heads = text_config.num_key_value_heads
             key_value_dim = head_dim * num_key_value_heads
             query_dim = head_dim * num_attention_heads
 
@@ -130,8 +135,11 @@ def convert_state_dict(original_state_dict: dict, config: Mistral3Config):
     return new_dict
 
 
-def convert_config(original_config: dict, max_position_embeddings: int = 262144):
+def convert_config(original_config: dict, max_position_embeddings: int = 262144, is_vision: bool = True):
     original_vision_config = original_config.pop("vision_encoder", None)
+    assert is_vision == (original_vision_config is not None), (
+        f"is_vision={is_vision} but original_vision_config={original_vision_config}"
+    )
     original_text_config = original_config
 
     # Text config
@@ -159,9 +167,9 @@ def convert_config(original_config: dict, max_position_embeddings: int = 262144)
             "original_max_position_embeddings": original_config["yarn"]["original_max_position_embeddings"],
             "beta_fast": float(original_config["yarn"]["beta"]),
             "beta_slow": float(original_config["yarn"]["alpha"]),
-            "mscale_all_dim": 1.0,
+            "mscale_all_dim": 1.0 if is_vision else 0.0,
             "mscale": 1.0,
-            "llama_4_scaling_beta": original_config["llama_4_scaling"]["beta"],
+            "llama_4_scaling_beta": original_config.get("llama_4_scaling", {}).get("beta", 0),
         }
 
     # These are not always defined depending on `params.json`
@@ -173,11 +181,25 @@ def convert_config(original_config: dict, max_position_embeddings: int = 262144)
     if new_text_config_kwargs["sliding_window"] is not None:
         new_text_config_kwargs["sliding_window"] = int(new_text_config_kwargs["sliding_window"])
 
-    new_text_config = Ministral3Config(**new_text_config_kwargs)
+    def get_maybe_quant_config() -> dict:
+        kwargs = {}
+        if original_config.get("quantization", {}).get("qformat_weight") == "fp8_e4m3":
+            assert original_config["quantization"]["qscheme_act"] == "TENSOR"
+            quantization_config = {
+                "activation_scheme": "static",
+                "modules_to_not_convert": ["model.vision_tower", "model.multi_modal_projector", "lm_head"],
+                "quant_method": "fp8",
+                "weight_block_size": None,
+            }
+            kwargs["quantization_config"] = AutoQuantizationConfig.from_dict(quantization_config)
+        return kwargs
 
     # No vision
     if original_vision_config is None:
+        new_text_config = Ministral3Config(**new_text_config_kwargs, **get_maybe_quant_config())
         return new_text_config
+    else:
+        new_text_config = Ministral3Config(**new_text_config_kwargs)
 
     # Vision config
     new_vision_config = original_vision_config
@@ -191,17 +213,6 @@ def convert_config(original_config: dict, max_position_embeddings: int = 262144)
     _ = new_vision_config.pop("max_image_size")
     new_vision_config = PixtralVisionConfig(hidden_act="silu", **new_vision_config)
 
-    kwargs = {}
-    if original_config.get("quantization", {}).get("qformat_weight") == "fp8_e4m3":
-        assert original_config["quantization"]["qscheme_act"] == "TENSOR"
-        quantization_config = {
-            "activation_scheme": "static",
-            "modules_to_not_convert": ["model.vision_tower", "model.multi_modal_projector"],
-            "quant_method": "fp8",
-            "weight_block_size": None,
-        }
-        kwargs["quantization_config"] = AutoQuantizationConfig.from_dict(quantization_config)
-
     new_config = Mistral3Config(
         vision_config=new_vision_config,
         text_config=new_text_config,
@@ -209,7 +220,7 @@ def convert_config(original_config: dict, max_position_embeddings: int = 262144)
         image_token_id=image_token_id,
         spatial_merge_size=spatial_merge_size,
         vision_feature_layer=-1,
-        **kwargs,
+        **get_maybe_quant_config(),
     )
 
     return new_config
@@ -218,7 +229,8 @@ def convert_and_write_model(input_dir: str, output_dir: str, max_position_embedd
     """Convert the model and save it (this implicitly save the config as well)."""
     params = read_json(os.path.join(input_dir, "params.json"))
 
-    config = convert_config(params, max_position_embeddings)
+    is_vision = params.get("vision_encoder") is not None
+    config = convert_config(params, max_position_embeddings, is_vision)
 
     full_state_dict = {}
     # The model may be split between different files, but a single nn.Module is always fully present in a single file
@@ -228,8 +240,10 @@ def convert_and_write_model(input_dir: str, output_dir: str, max_position_embedd
         new_dict = convert_state_dict(original_state_dict, config)
         full_state_dict.update(new_dict)
 
-    if config.text_config.tie_word_embeddings:
-        full_state_dict["lm_head.weight"] = full_state_dict["model.language_model.embed_tokens.weight"]
+    text_config = config.text_config if is_vision else config
+    if text_config.tie_word_embeddings:
+        model_key = "model.language_model" if is_vision else "model"
+        full_state_dict["lm_head.weight"] = full_state_dict[f"{model_key}.embed_tokens.weight"]
 
     # Load weights into model and resave them
     with torch.device("meta"):
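
Not part of the patch: a minimal sketch of how the refactored key mapping behaves for vision vs. text-only checkpoints. The `get_sd_mapping` / `map_old_key_to_new` names come from the diff above, but the mapping is trimmed to two patterns for illustration, and the handling of unmatched keys is assumed since the end of the real loop falls outside this hunk.

```python
import re


# Trimmed-down copy of get_sd_mapping() from the patch: only two patterns kept for illustration.
def get_sd_mapping(has_vision: bool) -> dict:
    model_key = "model.language_model" if has_vision else "model"
    return {
        r"^tok_embeddings.weight": rf"{model_key}.embed_tokens.weight",
        r"^layers.(\d+).attention.w(q|k|v|o).weight": rf"{model_key}.layers.\1.self_attn.\2_proj.weight",
    }


def map_old_key_to_new(old_key, mapping):
    # Same early-exit loop as in the converter: the first matching pattern wins.
    for pattern, replacement in mapping.items():
        new_key, n_replace = re.subn(pattern, replacement, old_key)
        if n_replace > 0:
            return new_key
    # Assumed behavior for unmatched keys; the converter's actual handling is outside this hunk.
    raise ValueError(f"Could not map key: {old_key}")


# Vision checkpoint ("vision_encoder" present in params.json) -> text weights land under model.language_model
assert (
    map_old_key_to_new("layers.3.attention.wq.weight", get_sd_mapping(has_vision=True))
    == "model.language_model.layers.3.self_attn.q_proj.weight"
)

# Text-only checkpoint -> text weights land directly under model
assert (
    map_old_key_to_new("layers.3.attention.wq.weight", get_sd_mapping(has_vision=False))
    == "model.layers.3.self_attn.q_proj.weight"
)
```

The early-exit `re.subn` loop itself is unchanged by the patch; the behavioral difference is only the target prefix, `model.language_model.*` when `params.json` declares a `vision_encoder` and plain `model.*` otherwise, which is also why the tied `lm_head.weight` is now copied from `f"{model_key}.embed_tokens.weight"` in `convert_and_write_model`.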