[Devstral] Make sure FP8 conversion works correctly #42715
| Original file line number | Diff line number | Diff line change |
|---|---|---|
@@ -35,46 +35,48 @@ | |
| # fmt: off | ||
| STATE_DICT_MAPPING = { | ||
| # Text model keys | ||
| r"^output.weight": r"lm_head.weight", | ||
| r"^norm.weight": r"model.language_model.norm.weight", | ||
| r"^tok_embeddings.weight": r"model.language_model.embed_tokens.weight", | ||
| r"^layers.(\d+).attention_norm.weight": r"model.language_model.layers.\1.input_layernorm.weight", | ||
| r"^layers.(\d+).ffn_norm.weight": r"model.language_model.layers.\1.post_attention_layernorm.weight", | ||
| r"^layers.(\d+).attention.w(q|k|v|o).weight": r"model.language_model.layers.\1.self_attn.\2_proj.weight", | ||
| r"^layers.(\d+).feed_forward.w1.weight": r"model.language_model.layers.\1.mlp.gate_proj.weight", | ||
| r"^layers.(\d+).feed_forward.w2.weight": r"model.language_model.layers.\1.mlp.down_proj.weight", | ||
| r"^layers.(\d+).feed_forward.w3.weight": r"model.language_model.layers.\1.mlp.up_proj.weight", | ||
| r"^layers.(\d+).attention.w(q|k|v|o).qscale_act": r"model.language_model.layers.\1.self_attn.\2_proj.activation_scale", | ||
| r"^layers.(\d+).feed_forward.w1.qscale_act": r"model.language_model.layers.\1.mlp.gate_proj.activation_scale", | ||
| r"^layers.(\d+).feed_forward.w2.qscale_act": r"model.language_model.layers.\1.mlp.down_proj.activation_scale", | ||
| r"^layers.(\d+).feed_forward.w3.qscale_act": r"model.language_model.layers.\1.mlp.up_proj.activation_scale", | ||
| r"^layers.(\d+).attention.w(q|k|v|o).qscale_weight": r"model.language_model.layers.\1.self_attn.\2_proj.weight_scale_inv", | ||
| r"^layers.(\d+).feed_forward.w1.qscale_weight": r"model.language_model.layers.\1.mlp.gate_proj.weight_scale_inv", | ||
| r"^layers.(\d+).feed_forward.w2.qscale_weight": r"model.language_model.layers.\1.mlp.down_proj.weight_scale_inv", | ||
| r"^layers.(\d+).feed_forward.w3.qscale_weight": r"model.language_model.layers.\1.mlp.up_proj.weight_scale_inv", | ||
| # Vision model keys | ||
| r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"model.vision_tower.transformer.layers.\1.attention_norm.weight", | ||
| r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"model.vision_tower.transformer.layers.\1.ffn_norm.weight", | ||
| r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"model.vision_tower.transformer.layers.\1.attention.\2_proj.weight", | ||
| r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight", | ||
| r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.down_proj.weight", | ||
| r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.up_proj.weight", | ||
| r"^vision_language_adapter.w_in": r"model.multi_modal_projector.linear_1", | ||
| r"^vision_language_adapter.w_out": r"model.multi_modal_projector.linear_2", | ||
| r"^vision_encoder.ln_pre.weight": r"model.vision_tower.ln_pre.weight", | ||
| r"^vision_encoder.patch_conv.weight": r"model.vision_tower.patch_conv.weight", | ||
| r"^patch_merger.merging_layer.weight": r"model.multi_modal_projector.patch_merger.merging_layer.weight", | ||
| r"^pre_mm_projector_norm.weight": r"model.multi_modal_projector.norm.weight", | ||
| } | ||
| def get_sd_mapping(has_vision: bool) -> dict: | ||
| model_key = "model.language_model" if has_vision else "model" | ||
| return { | ||
| # Text model keys | ||
| r"^output.weight": r"lm_head.weight", | ||
| r"^norm.weight": rf"{model_key}.norm.weight", | ||
| r"^tok_embeddings.weight": rf"{model_key}.embed_tokens.weight", | ||
| r"^layers.(\d+).attention_norm.weight": rf"{model_key}.layers.\1.input_layernorm.weight", | ||
| r"^layers.(\d+).ffn_norm.weight": rf"{model_key}.layers.\1.post_attention_layernorm.weight", | ||
| r"^layers.(\d+).attention.w(q|k|v|o).weight": rf"{model_key}.layers.\1.self_attn.\2_proj.weight", | ||
| r"^layers.(\d+).feed_forward.w1.weight": rf"{model_key}.layers.\1.mlp.gate_proj.weight", | ||
| r"^layers.(\d+).feed_forward.w2.weight": rf"{model_key}.layers.\1.mlp.down_proj.weight", | ||
| r"^layers.(\d+).feed_forward.w3.weight": rf"{model_key}.layers.\1.mlp.up_proj.weight", | ||
| r"^layers.(\d+).attention.w(q|k|v|o).qscale_act": rf"{model_key}.layers.\1.self_attn.\2_proj.activation_scale", | ||
| r"^layers.(\d+).feed_forward.w1.qscale_act": rf"{model_key}.layers.\1.mlp.gate_proj.activation_scale", | ||
| r"^layers.(\d+).feed_forward.w2.qscale_act": rf"{model_key}.layers.\1.mlp.down_proj.activation_scale", | ||
| r"^layers.(\d+).feed_forward.w3.qscale_act": rf"{model_key}.layers.\1.mlp.up_proj.activation_scale", | ||
| r"^layers.(\d+).attention.w(q|k|v|o).qscale_weight": rf"{model_key}.layers.\1.self_attn.\2_proj.weight_scale_inv", | ||
| r"^layers.(\d+).feed_forward.w1.qscale_weight": rf"{model_key}.layers.\1.mlp.gate_proj.weight_scale_inv", | ||
| r"^layers.(\d+).feed_forward.w2.qscale_weight": rf"{model_key}.layers.\1.mlp.down_proj.weight_scale_inv", | ||
| r"^layers.(\d+).feed_forward.w3.qscale_weight": rf"{model_key}.layers.\1.mlp.up_proj.weight_scale_inv", | ||
| # Vision model keys | ||
| r"vision_encoder.transformer.layers.(\d+).attention_norm.weight": r"model.vision_tower.transformer.layers.\1.attention_norm.weight", | ||
| r"^vision_encoder.transformer.layers.(\d+).ffn_norm.weight": r"model.vision_tower.transformer.layers.\1.ffn_norm.weight", | ||
| r"^vision_encoder.transformer.layers.(\d+).attention.w(q|k|v|o).weight": r"model.vision_tower.transformer.layers.\1.attention.\2_proj.weight", | ||
| r"^vision_encoder.transformer.layers.(\d+).feed_forward.w1.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.gate_proj.weight", | ||
| r"^vision_encoder.transformer.layers.(\d+).feed_forward.w2.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.down_proj.weight", | ||
| r"^vision_encoder.transformer.layers.(\d+).feed_forward.w3.weight": r"model.vision_tower.transformer.layers.\1.feed_forward.up_proj.weight", | ||
| r"^vision_language_adapter.w_in": r"model.multi_modal_projector.linear_1", | ||
| r"^vision_language_adapter.w_out": r"model.multi_modal_projector.linear_2", | ||
| r"^vision_encoder.ln_pre.weight": r"model.vision_tower.ln_pre.weight", | ||
| r"^vision_encoder.patch_conv.weight": r"model.vision_tower.patch_conv.weight", | ||
| r"^patch_merger.merging_layer.weight": r"model.multi_modal_projector.patch_merger.merging_layer.weight", | ||
| r"^pre_mm_projector_norm.weight": r"model.multi_modal_projector.norm.weight", | ||
| } | ||
| # fmt: on | ||
| def map_old_key_to_new(old_key): | ||
| def map_old_key_to_new(old_key, mapping): | ||
| """Map of a key of the original state dict to the equivalent key in HF format""" | ||
| for pattern, replacement in STATE_DICT_MAPPING.items(): | ||
| for pattern, replacement in mapping.items(): | ||
| new_key, n_replace = re.subn(pattern, replacement, old_key) | ||
| # Early exit of the loop | ||
| if n_replace > 0: | ||
@@ -100,11 +102,13 @@ def convert_state_dict(original_state_dict: dict, config: Mistral3Config): | |
| """Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case).""" | ||
| new_dict = {} | ||
| is_vision = isinstance(config, Mistral3Config) | ||
| mapping = get_sd_mapping(is_vision) | ||
| for old_key, tensor in original_state_dict.items(): | ||
| if "fake_quantizer" in old_key: | ||
| continue | ||
| new_key = map_old_key_to_new(old_key) | ||
| new_key = map_old_key_to_new(old_key, mapping) | ||
| if "vision" in old_key: | ||
| num_attention_heads = config.vision_config.num_attention_heads | ||
@@ -114,10 +118,11 @@ def convert_state_dict(original_state_dict: dict, config: Mistral3Config): | |
| key_value_dim = head_dim * num_attention_heads | ||
| query_dim = head_dim * num_attention_heads | ||
| else: | ||
| num_attention_heads = config.text_config.num_attention_heads | ||
| hidden_size = config.text_config.hidden_size | ||
| head_dim = config.text_config.head_dim | ||
| num_key_value_heads = config.text_config.num_key_value_heads | ||
| text_config = config.text_config if is_vision else config | ||
| num_attention_heads = text_config.num_attention_heads | ||
| hidden_size = text_config.hidden_size | ||
| head_dim = text_config.head_dim | ||
| num_key_value_heads = text_config.num_key_value_heads | ||
| key_value_dim = head_dim * num_key_value_heads | ||
| query_dim = head_dim * num_attention_heads | ||
@@ -130,8 +135,11 @@ def convert_state_dict(original_state_dict: dict, config: Mistral3Config): | |
| return new_dict | ||
| def convert_config(original_config: dict, max_position_embeddings: int = 262144): | ||
| def convert_config(original_config: dict, max_position_embeddings: int = 262144, is_vision: bool = True): | ||
| original_vision_config = original_config.pop("vision_encoder", None) | ||
| assert is_vision == (original_vision_config is not None), ( | ||
| f"is_vision={is_vision} but original_vision_config={original_vision_config}" | ||
| ) | ||
| original_text_config = original_config | ||
| # Text config | ||
@@ -159,9 +167,9 @@ def convert_config(original_config: dict, max_position_embeddings: int = 262144) | |
| "original_max_position_embeddings": original_config["yarn"]["original_max_position_embeddings"], | ||
| "beta_fast": float(original_config["yarn"]["beta"]), | ||
| "beta_slow": float(original_config["yarn"]["alpha"]), | ||
| "mscale_all_dim": 1.0, | ||
| "mscale_all_dim": 1.0 if is_vision else 0.0, | ||
| "mscale": 1.0, | ||
| "llama_4_scaling_beta": original_config["llama_4_scaling"]["beta"], | ||
| "llama_4_scaling_beta": original_config.get("llama_4_scaling", {}).get("beta", 0), | ||
| } | ||
| # These are not always defined depending on `params.json` | ||
@@ -173,11 +181,25 @@ def convert_config(original_config: dict, max_position_embeddings: int = 262144) | |
| if new_text_config_kwargs["sliding_window"] is not None: | ||
| new_text_config_kwargs["sliding_window"] = int(new_text_config_kwargs["sliding_window"]) | ||
| new_text_config = Ministral3Config(**new_text_config_kwargs) | ||
| def get_maybe_quant_config() -> dict: | ||
| kwargs = {} | ||
| if original_config.get("quantization", {}).get("qformat_weight") == "fp8_e4m3": | ||
| assert original_config["quantization"]["qscheme_act"] == "TENSOR" | ||
| quantization_config = { | ||
| "activation_scheme": "static", | ||
| "modules_to_not_convert": ["model.vision_tower", "model.multi_modal_projector", "lm_head"], | ||
| "quant_method": "fp8", | ||
| "weight_block_size": None, | ||
| } | ||
| kwargs["quantization_config"] = AutoQuantizationConfig.from_dict(quantization_config) | ||
| return kwargs | ||
| # No vision | ||
| if original_vision_config is None: | ||
| new_text_config = Ministral3Config(**new_text_config_kwargs, **get_maybe_quant_config()) | ||
| return new_text_config | ||
| else: | ||
Contributor
Would it make sense to remove the `else`, since we do an early return if the condition evaluates to true?
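A minimal sketch of the suggested restructuring, using the names from the diff above (illustrative only, not part of the PR):

```python
# No vision: return early, so no `else` branch is needed afterwards.
if original_vision_config is None:
    return Ministral3Config(**new_text_config_kwargs, **get_maybe_quant_config())

# Vision path continues at the outer indentation level.
new_text_config = Ministral3Config(**new_text_config_kwargs)
```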
| new_text_config = Ministral3Config(**new_text_config_kwargs) | ||
| # Vision config | ||
| new_vision_config = original_vision_config | ||
@@ -191,25 +213,14 @@ def convert_config(original_config: dict, max_position_embeddings: int = 262144) | |
| _ = new_vision_config.pop("max_image_size") | ||
| new_vision_config = PixtralVisionConfig(hidden_act="silu", **new_vision_config) | ||
| kwargs = {} | ||
| if original_config.get("quantization", {}).get("qformat_weight") == "fp8_e4m3": | ||
| assert original_config["quantization"]["qscheme_act"] == "TENSOR" | ||
| quantization_config = { | ||
| "activation_scheme": "static", | ||
| "modules_to_not_convert": ["model.vision_tower", "model.multi_modal_projector"], | ||
| "quant_method": "fp8", | ||
| "weight_block_size": None, | ||
| } | ||
| kwargs["quantization_config"] = AutoQuantizationConfig.from_dict(quantization_config) | ||
| new_config = Mistral3Config( | ||
| vision_config=new_vision_config, | ||
| text_config=new_text_config, | ||
| multimodal_projector_bias=adapter_bias, | ||
| image_token_id=image_token_id, | ||
| spatial_merge_size=spatial_merge_size, | ||
| vision_feature_layer=-1, | ||
| **kwargs, | ||
| **get_maybe_quant_config(), | ||
| ) | ||
| return new_config | ||
@@ -218,7 +229,8 @@ def convert_and_write_model(input_dir: str, output_dir: str, max_position_embedd | |
| """Convert the model and save it (this implicitly save the config as well).""" | ||
| params = read_json(os.path.join(input_dir, "params.json")) | ||
| config = convert_config(params, max_position_embeddings) | ||
| is_vision = params.get("vision_encoder") is not None | ||
Contributor
Would it make sense to default to "False" instead?
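The thread does not spell out which default is meant; one plausible reading is the `is_vision` parameter that this PR adds to `convert_config`, which would then look like this (hypothetical sketch, not PR code):

```python
# Hypothetical: make text-only conversion the default and require callers to
# opt in to the vision path; call sites that pass is_vision explicitly,
# as convert_and_write_model does, would be unaffected.
def convert_config(original_config: dict, max_position_embeddings: int = 262144, is_vision: bool = False):
    ...
```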
| config = convert_config(params, max_position_embeddings, is_vision) | ||
| full_state_dict = {} | ||
| # The model may be split between different files, but a single nn.Module is always fully present in a single file | ||
@@ -228,8 +240,10 @@ def convert_and_write_model(input_dir: str, output_dir: str, max_position_embedd | |
| new_dict = convert_state_dict(original_state_dict, config) | ||
| full_state_dict.update(new_dict) | ||
| if config.text_config.tie_word_embeddings: | ||
| full_state_dict["lm_head.weight"] = full_state_dict["model.language_model.embed_tokens.weight"] | ||
| text_config = config.text_config if is_vision else config | ||
| if text_config.tie_word_embeddings: | ||
| model_key = "model.language_model" if is_vision else "model" | ||
| full_state_dict["lm_head.weight"] = full_state_dict[f"{model_key}.embed_tokens.weight"] | ||
| # Load weights into model and resave them | ||
| with torch.device("meta"): | ||
Let's just add the quantization config in `new_config` and not potentially in the `new_text_config`. In the `config.json` that you converted, everything looks good.
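A rough sketch of that suggestion for the vision path (illustrative only; `get_maybe_quant_config` and the other names come from the diff above):

```python
# Keep the nested text config free of quantization metadata...
new_text_config = Ministral3Config(**new_text_config_kwargs)

# ...and attach the FP8 quantization config only to the top-level config.
new_config = Mistral3Config(
    vision_config=new_vision_config,
    text_config=new_text_config,
    multimodal_projector_bias=adapter_bias,
    image_token_id=image_token_id,
    spatial_merge_size=spatial_merge_size,
    vision_feature_layer=-1,
    **get_maybe_quant_config(),
)
```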