Changes from all commits
29 commits
a817577
add first generation tutorial
patrickvonplaten Jun 21, 2022
cdf3983
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Jun 21, 2022
f4ec985
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Jun 22, 2022
5edcbd9
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Jun 23, 2022
855eed9
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Jun 23, 2022
64eb136
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Jun 30, 2022
677983c
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Jul 18, 2022
232a594
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Aug 17, 2022
fe3a2e5
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Aug 19, 2022
fe9cc03
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Aug 26, 2022
17446d3
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Aug 26, 2022
0590843
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Aug 31, 2022
bb7e53f
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Sep 5, 2022
29988d8
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Sep 26, 2022
5ca8eac
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Oct 11, 2022
52f273d
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Oct 27, 2022
2b84252
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Nov 7, 2022
5787dda
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Nov 8, 2022
b9ac641
Merge branch 'main' of https://github.com/huggingface/transformers
patrickvonplaten Nov 9, 2022
e618a07
jkXMerge branch 'main' of https://github.com/patrickvonplaten/transfo…
patrickvonplaten Dec 5, 2025
dd2abb9
WIP
patrickvonplaten Dec 8, 2025
fd602a1
WIP
patrickvonplaten Dec 8, 2025
c254d5f
WIP
patrickvonplaten Dec 8, 2025
013cbc8
WIP
patrickvonplaten Dec 8, 2025
3635681
WIP
patrickvonplaten Dec 9, 2025
3e829f8
uP-
patrickvonplaten Dec 9, 2025
585afb9
Merge branch 'main' into small_addition_fp8_convert
patrickvonplaten Dec 9, 2025
7187aba
uP-
patrickvonplaten Dec 9, 2025
2d2df3d
WIP
patrickvonplaten Dec 9, 2025
134 changes: 74 additions & 60 deletions src/transformers/models/ministral3/convert_ministral3_weights_to_hf.py
@@ -35,46 +35,48 @@


# fmt: off
STATE_DICT_MAPPING = {
# Text model keys
r"^output.weight": r"lm_head.weight",
r"^norm.weight": r"model.language_model.norm.weight",
r"^tok_embeddings.weight": r"model.language_model.embed_tokens.weight",
r"^layers.(\d+).attention_norm.weight": r"model.language_model.layers.\1.input_layernorm.weight",
r"^layers.(\d+).ffn_norm.weight": r"model.language_model.layers.\1.post_attention_layernorm.weight",
r"^layers.(\d+).attention.w(q|k|v|o).weight": r"model.language_model.layers.\1.self_attn.\2_proj.weight",
r"^layers.(\d+).feed_forward.w1.weight": r"model.language_model.layers.\1.mlp.gate_proj.weight",
r"^layers.(\d+).feed_forward.w2.weight": r"model.language_model.layers.\1.mlp.down_proj.weight",
r"^layers.(\d+).feed_forward.w3.weight": r"model.language_model.layers.\1.mlp.up_proj.weight",
r"^layers.(\d+).attention.w(q|k|v|o).qscale_act": r"model.language_model.layers.\1.self_attn.\2_proj.activation_scale",
r"^layers.(\d+).feed_forward.w1.qscale_act": r"model.language_model.layers.\1.mlp.gate_proj.activation_scale",
r"^layers.(\d+).feed_forward.w2.qscale_act": r"model.language_model.layers.\1.mlp.down_proj.activation_scale",
r"^layers.(\d+).feed_forward.w3.qscale_act": r"model.language_model.layers.\1.mlp.up_proj.activation_scale",
r"^layers.(\d+).attention.w(q|k|v|o).qscale_weight": r"model.language_model.layers.\1.self_attn.\2_proj.weight_scale_inv",
r"^layers.(\d+).feed_forward.w1.qscale_weight": r"model.language_model.layers.\1.mlp.gate_proj.weight_scale_inv",
r"^layers.(\d+).feed_forward.w2.qscale_weight": r"model.language_model.layers.\1.mlp.down_proj.weight_scale_inv",
r"^layers.(\d+).feed_forward.w3.qscale_weight": r"model.language_model.layers.\1.mlp.up_proj.weight_scale_inv",

# Vision model keys
r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"model.vision_tower.transformer.layers.\1.attention_norm.weight",
r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"model.vision_tower.transformer.layers.\1.ffn_norm.weight",
r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"model.vision_tower.transformer.layers.\1.attention.\2_proj.weight",
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
r"^vision_language_adapter.w_in": r"model.multi_modal_projector.linear_1",
r"^vision_language_adapter.w_out": r"model.multi_modal_projector.linear_2",
r"^vision_encoder.ln_pre.weight": r"model.vision_tower.ln_pre.weight",
r"^vision_encoder.patch_conv.weight": r"model.vision_tower.patch_conv.weight",
r"^patch_merger.merging_layer.weight": r"model.multi_modal_projector.patch_merger.merging_layer.weight",
r"^pre_mm_projector_norm.weight": r"model.multi_modal_projector.norm.weight",
}
def get_sd_mapping(has_vision: bool) -> dict:
model_key = "model.language_model" if has_vision else "model"
return {
# Text model keys
r"^output.weight": r"lm_head.weight",
r"^norm.weight": rf"{model_key}.norm.weight",
r"^tok_embeddings.weight": rf"{model_key}.embed_tokens.weight",
r"^layers.(\d+).attention_norm.weight": rf"{model_key}.layers.\1.input_layernorm.weight",
r"^layers.(\d+).ffn_norm.weight": rf"{model_key}.layers.\1.post_attention_layernorm.weight",
r"^layers.(\d+).attention.w(q|k|v|o).weight": rf"{model_key}.layers.\1.self_attn.\2_proj.weight",
r"^layers.(\d+).feed_forward.w1.weight": rf"{model_key}.layers.\1.mlp.gate_proj.weight",
r"^layers.(\d+).feed_forward.w2.weight": rf"{model_key}.layers.\1.mlp.down_proj.weight",
r"^layers.(\d+).feed_forward.w3.weight": rf"{model_key}.layers.\1.mlp.up_proj.weight",
r"^layers.(\d+).attention.w(q|k|v|o).qscale_act": rf"{model_key}.layers.\1.self_attn.\2_proj.activation_scale",
r"^layers.(\d+).feed_forward.w1.qscale_act": rf"{model_key}.layers.\1.mlp.gate_proj.activation_scale",
r"^layers.(\d+).feed_forward.w2.qscale_act": rf"{model_key}.layers.\1.mlp.down_proj.activation_scale",
r"^layers.(\d+).feed_forward.w3.qscale_act": rf"{model_key}.layers.\1.mlp.up_proj.activation_scale",
r"^layers.(\d+).attention.w(q|k|v|o).qscale_weight": rf"{model_key}.layers.\1.self_attn.\2_proj.weight_scale_inv",
r"^layers.(\d+).feed_forward.w1.qscale_weight": rf"{model_key}.layers.\1.mlp.gate_proj.weight_scale_inv",
r"^layers.(\d+).feed_forward.w2.qscale_weight": rf"{model_key}.layers.\1.mlp.down_proj.weight_scale_inv",
r"^layers.(\d+).feed_forward.w3.qscale_weight": rf"{model_key}.layers.\1.mlp.up_proj.weight_scale_inv",

# Vision model keys
r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"model.vision_tower.transformer.layers.\1.attention_norm.weight",
r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"model.vision_tower.transformer.layers.\1.ffn_norm.weight",
r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"model.vision_tower.transformer.layers.\1.attention.\2_proj.weight",
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight",
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.down_proj.weight",
r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.up_proj.weight",
r"^vision_language_adapter.w_in": r"model.multi_modal_projector.linear_1",
r"^vision_language_adapter.w_out": r"model.multi_modal_projector.linear_2",
r"^vision_encoder.ln_pre.weight": r"model.vision_tower.ln_pre.weight",
r"^vision_encoder.patch_conv.weight": r"model.vision_tower.patch_conv.weight",
r"^patch_merger.merging_layer.weight": r"model.multi_modal_projector.patch_merger.merging_layer.weight",
r"^pre_mm_projector_norm.weight": r"model.multi_modal_projector.norm.weight",
}
# fmt: on


def map_old_key_to_new(old_key):
def map_old_key_to_new(old_key, mapping):
"""Map of a key of the original state dict to the equivalent key in HF format"""
for pattern, replacement in STATE_DICT_MAPPING.items():
for pattern, replacement in mapping.items():
new_key, n_replace = re.subn(pattern, replacement, old_key)
# Early exit of the loop
if n_replace > 0:
@@ -100,11 +102,13 @@ def convert_state_dict(original_state_dict: dict, config: Mistral3Config):
"""Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case)."""
new_dict = {}

is_vision = isinstance(config, Mistral3Config)
mapping = get_sd_mapping(is_vision)
for old_key, tensor in original_state_dict.items():
if "fake_quantizer" in old_key:
continue

new_key = map_old_key_to_new(old_key)
new_key = map_old_key_to_new(old_key, mapping)

if "vision" in old_key:
num_attention_heads = config.vision_config.num_attention_heads
@@ -114,10 +118,11 @@ def convert_state_dict(original_state_dict: dict, config: Mistral3Config):
key_value_dim = head_dim * num_attention_heads
query_dim = head_dim * num_attention_heads
else:
num_attention_heads = config.text_config.num_attention_heads
hidden_size = config.text_config.hidden_size
head_dim = config.text_config.head_dim
num_key_value_heads = config.text_config.num_key_value_heads
text_config = config.text_config if is_vision else config
num_attention_heads = text_config.num_attention_heads
hidden_size = text_config.hidden_size
head_dim = text_config.head_dim
num_key_value_heads = text_config.num_key_value_heads
key_value_dim = head_dim * num_key_value_heads
query_dim = head_dim * num_attention_heads

@@ -130,8 +135,11 @@ def convert_state_dict(original_state_dict: dict, config: Mistral3Config):
return new_dict


def convert_config(original_config: dict, max_position_embeddings: int = 262144):
def convert_config(original_config: dict, max_position_embeddings: int = 262144, is_vision: bool = True):
original_vision_config = original_config.pop("vision_encoder", None)
assert is_vision == (original_vision_config is not None), (
f"is_vision={is_vision} but original_vision_config={original_vision_config}"
)
original_text_config = original_config

# Text config
@@ -159,9 +167,9 @@ def convert_config(original_config: dict, max_position_embeddings: int = 262144)
"original_max_position_embeddings": original_config["yarn"]["original_max_position_embeddings"],
"beta_fast": float(original_config["yarn"]["beta"]),
"beta_slow": float(original_config["yarn"]["alpha"]),
"mscale_all_dim": 1.0,
"mscale_all_dim": 1.0 if is_vision else 0.0,
"mscale": 1.0,
"llama_4_scaling_beta": original_config["llama_4_scaling"]["beta"],
"llama_4_scaling_beta": original_config.get("llama_4_scaling", {}).get("beta", 0),
}

# These are not always defined depending on `params.json`
@@ -173,11 +181,25 @@ def convert_config(original_config: dict, max_position_embeddings: int = 262144)
if new_text_config_kwargs["sliding_window"] is not None:
new_text_config_kwargs["sliding_window"] = int(new_text_config_kwargs["sliding_window"])

new_text_config = Ministral3Config(**new_text_config_kwargs)
def get_maybe_quant_config() -> dict:
kwargs = {}
if original_config.get("quantization", {}).get("qformat_weight") == "fp8_e4m3":
assert original_config["quantization"]["qscheme_act"] == "TENSOR"
quantization_config = {
"activation_scheme": "static",
"modules_to_not_convert": ["model.vision_tower", "model.multi_modal_projector", "lm_head"],
"quant_method": "fp8",
"weight_block_size": None,
}
kwargs["quantization_config"] = AutoQuantizationConfig.from_dict(quantization_config)
return kwargs

# No vision
if original_vision_config is None:
new_text_config = Ministral3Config(**new_text_config_kwargs, **get_maybe_quant_config())
@SunMarc (Member) commented on Dec 10, 2025:
Let's just add the quantization config in new_config and not potentially in the new_text_config. In the config.json that you converted, everything looks good.
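One hypothetical reading of that suggestion (a sketch, not part of the diff): in the vision branch, the quantization kwargs would be attached only to the wrapping Mistral3Config, never to the inner text config.

# Hypothetical sketch of the suggestion (not part of this PR).
new_text_config = Ministral3Config(**new_text_config_kwargs)  # no quantization kwargs here
new_config = Mistral3Config(
    vision_config=new_vision_config,
    text_config=new_text_config,
    multimodal_projector_bias=adapter_bias,
    image_token_id=image_token_id,
    spatial_merge_size=spatial_merge_size,
    vision_feature_layer=-1,
    **get_maybe_quant_config(),  # quantization kwargs only on the top-level config
)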

return new_text_config
else:
A Contributor commented:
Would it make sense to remove the else, since we do an early return if the condition evaluates to true?
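A minimal sketch of that refactor, assuming the surrounding code stays as in this diff:

# Hypothetical early-return refactor (not part of this PR).
if original_vision_config is None:
    # Text-only checkpoint: build and return the top-level config directly.
    return Ministral3Config(**new_text_config_kwargs, **get_maybe_quant_config())

# From here on we are always in the vision case, so no else block is needed.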

new_text_config = Ministral3Config(**new_text_config_kwargs)

# Vision config
new_vision_config = original_vision_config
@@ -191,25 +213,14 @@ def convert_config(original_config: dict, max_position_embeddings: int = 262144)
_ = new_vision_config.pop("max_image_size")
new_vision_config = PixtralVisionConfig(hidden_act="silu", **new_vision_config)

kwargs = {}
if original_config.get("quantization", {}).get("qformat_weight") == "fp8_e4m3":
assert original_config["quantization"]["qscheme_act"] == "TENSOR"
quantization_config = {
"activation_scheme": "static",
"modules_to_not_convert": ["model.vision_tower", "model.multi_modal_projector"],
"quant_method": "fp8",
"weight_block_size": None,
}
kwargs["quantization_config"] = AutoQuantizationConfig.from_dict(quantization_config)

new_config = Mistral3Config(
vision_config=new_vision_config,
text_config=new_text_config,
multimodal_projector_bias=adapter_bias,
image_token_id=image_token_id,
spatial_merge_size=spatial_merge_size,
vision_feature_layer=-1,
**kwargs,
**get_maybe_quant_config(),
)
return new_config

@@ -218,7 +229,8 @@ def convert_and_write_model(input_dir: str, output_dir: str, max_position_embedd
"""Convert the model and save it (this implicitly save the config as well)."""
params = read_json(os.path.join(input_dir, "params.json"))

config = convert_config(params, max_position_embeddings)
is_vision = params.get("vision_encoder") is not None
A Contributor commented:
Would it make sense to default to "False" instead?
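If this refers to the is_vision keyword default in convert_config, a hypothetical sketch of that alternative (assumed, not part of the diff):

# Hypothetical alternative default (not part of this PR).
def convert_config(original_config: dict, max_position_embeddings: int = 262144, is_vision: bool = False):
    ...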

config = convert_config(params, max_position_embeddings, is_vision)

full_state_dict = {}
# The model may be split between different files, but a single nn.Module is always fully present in a single file
@@ -228,8 +240,10 @@ def convert_and_write_model(input_dir: str, output_dir: str, max_position_embedd
new_dict = convert_state_dict(original_state_dict, config)
full_state_dict.update(new_dict)

if config.text_config.tie_word_embeddings:
full_state_dict["lm_head.weight"] = full_state_dict["model.language_model.embed_tokens.weight"]
text_config = config.text_config if is_vision else config
if text_config.tie_word_embeddings:
model_key = "model.language_model" if is_vision else "model"
full_state_dict["lm_head.weight"] = full_state_dict[f"{model_key}.embed_tokens.weight"]

# Load weights into model and resave them
with torch.device("meta"):