
@fakerybakery
Last active March 24, 2025 10:36
Convert the text portion of Mistral 3.1 to the Hugging Face (HF) format. IMPORTANT: this does not convert the vision layers yet; the resulting model will be a text-only LLM.
# Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
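#
# Example invocation (a sketch added for illustration; the script name and paths below are
# placeholders, not part of the original gist):
#
#     python convert_mistral_to_hf.py /path/to/mistral_checkpoint /path/to/hf_output \
#         --max_position_embeddings 32768 \
#         --template_name some-org/some-model-with-a-chat-template
#
# Vision weights in the checkpoint are skipped with a warning, so the output is a text-only
# MistralForCausalLM plus its tokenizer.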
import argparse
import json
import os
import re
import torch
from safetensors.torch import load_file
from transformers import AutoTokenizer, LlamaTokenizerFast, MistralConfig, MistralForCausalLM
from transformers.integrations.mistral import convert_tekken_tokenizer
# fmt: off
STATE_DICT_MAPPING = {
    # CausalLM keys
    r"^output.weight": r"lm_head.weight",
    # Model keys
    r"^norm.weight": r"model.norm.weight",
    r"^tok_embeddings.weight": r"model.embed_tokens.weight",
    # Layers keys
    r"^layers.(\d+).attention_norm.weight": r"model.layers.\1.input_layernorm.weight",
    r"^layers.(\d+).ffn_norm.weight": r"model.layers.\1.post_attention_layernorm.weight",
    # Attention keys
    r"^layers.(\d+).attention.w(q|k|v|o).weight": r"model.layers.\1.self_attn.\2_proj.weight",
    # MLP keys
    r"^layers.(\d+).feed_forward.w1.weight": r"model.layers.\1.mlp.gate_proj.weight",
    r"^layers.(\d+).feed_forward.w2.weight": r"model.layers.\1.mlp.down_proj.weight",
    r"^layers.(\d+).feed_forward.w3.weight": r"model.layers.\1.mlp.up_proj.weight",
}
# fmt: on
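# For illustration: an original key such as "layers.0.attention.wq.weight" is rewritten by the
# mapping above to "model.layers.0.self_attn.q_proj.weight"; keys that match none of the
# patterns (e.g. the vision tower) fall through and are skipped later with a warning.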
def map_old_key_to_new(old_key):
    """Map a key of the original state dict to the equivalent key in HF format."""
    for pattern, replacement in STATE_DICT_MAPPING.items():
        new_key, n_replace = re.subn(pattern, replacement, old_key)
        # Early exit of the loop
        if n_replace > 0:
            return new_key
    # Unmapped keys (e.g. vision layers) are skipped with a warning rather than raising an error
    print(f"Warning: {old_key} could not be mapped")
    return None
    # raise ValueError(f"Key: {old_key} could not be mapped (check the mapping).")
def read_json(path):
    with open(path, "r") as f:
        return json.load(f)
def permute_for_rope(tensor, n_heads, dim1, dim2):
    """Permute the weights for the RoPE formulation."""
    tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
    tensor = tensor.transpose(1, 2)
    tensor = tensor.reshape(dim1, dim2)
    return tensor
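# Note on `permute_for_rope` (explanatory comment, not from the original gist): the original
# checkpoints store the rotary (q/k) rows with the two halves of each rotation pair interleaved,
# while the HF implementation applies RoPE with a "rotate_half" layout that expects all first-half
# rows grouped before all second-half rows within each head. The view/transpose/reshape above
# reorders the rows accordingly; e.g. for a head of size 6, the row order changes from
# (0, 1, 2, 3, 4, 5) to (0, 2, 4, 1, 3, 5).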
def convert_state_dict(original_state_dict: dict, config: MistralConfig):
    """Convert a state dict file, when a single `nn.Module` is never sharded in different files (usual case)."""
    new_dict = {}
    num_attention_heads = config.num_attention_heads
    hidden_size = config.hidden_size
    head_dim = config.head_dim
    num_key_value_heads = config.num_key_value_heads
    key_value_dim = head_dim * num_key_value_heads
    query_dim = head_dim * num_attention_heads
    for old_key, tensor in original_state_dict.items():
        new_key = map_old_key_to_new(old_key)
        if new_key is not None:  # Only process if the key was successfully mapped
            if "q_proj" in new_key:
                tensor = tensor.view(num_attention_heads, head_dim, hidden_size).reshape(query_dim, hidden_size)
                tensor = permute_for_rope(tensor, num_attention_heads, query_dim, hidden_size)
            elif "k_proj" in new_key:
                tensor = tensor.view(num_key_value_heads, head_dim, hidden_size).reshape(key_value_dim, hidden_size)
                tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, hidden_size)
            elif "v_proj" in new_key:
                tensor = tensor.view(num_key_value_heads, head_dim, hidden_size).reshape(key_value_dim, hidden_size)
            new_dict[new_key] = tensor
        else:
            print(f"Skipping unmapped key: {old_key}")
    return new_dict
def get_concat_dim(key):
    """Return the dimension to concatenate the weights on."""
    concat_dim_1 = [
        r"model.embed_tokens.weight",
        r"model.layers.(\d+).self_attn.o_proj.weight",
        r"model.layers.(\d+).mlp.down_proj.weight",
    ]
    if any(re.search(pattern, key) for pattern in concat_dim_1):
        return 1
    return 0
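# Explanatory note (added, not from the original gist): in tensor-parallel checkpoints the
# embedding, o_proj and down_proj weights are sharded along their second (input/hidden) dimension,
# so their shards are concatenated on dim 1; every other weight is sharded along its output rows
# and is concatenated on dim 0.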
def convert_state_dict_sharded(loaded_shards: list[dict], config: MistralConfig):
    """Convert the state dict, when a single `nn.Module` is sharded across different files."""
    new_dict = {}
    num_shards = len(loaded_shards)
    n_heads = config.num_attention_heads
    dim = config.hidden_size
    dims_per_head = dim // n_heads
    num_key_value_heads = config.num_key_value_heads
    n_heads_per_shard = n_heads // num_shards
    num_local_key_value_heads = num_key_value_heads // num_shards
    key_value_dim = dim if n_heads == num_key_value_heads else dims_per_head * num_local_key_value_heads
    # Materialize the key list first, since the shards are popped while iterating
    original_keys = list(loaded_shards[0].keys())
    for old_key in original_keys:
        new_key = map_old_key_to_new(old_key)
        if new_key is None:  # Skip keys that could not be mapped (e.g. vision layers)
            print(f"Skipping unmapped key: {old_key}")
            continue
        cat_dim = get_concat_dim(new_key)
        if "q_proj" in new_key:
            tensor = torch.cat(
                [shard.pop(old_key).view(n_heads_per_shard, dims_per_head, dim) for shard in loaded_shards],
                dim=cat_dim,
            ).reshape(dim, dim)
            tensor = permute_for_rope(tensor, n_heads, dim, dim)
        elif "k_proj" in new_key:
            tensor = torch.cat(
                [shard.pop(old_key).view(num_local_key_value_heads, dims_per_head, dim) for shard in loaded_shards],
                dim=cat_dim,
            ).reshape(key_value_dim, dim)
            tensor = permute_for_rope(tensor, num_key_value_heads, key_value_dim, dim)
        elif "v_proj" in new_key:
            tensor = torch.cat(
                [shard.pop(old_key).view(num_local_key_value_heads, dims_per_head, dim) for shard in loaded_shards],
                dim=cat_dim,
            ).reshape(key_value_dim, dim)
        elif "input_layernorm" in new_key or "post_attention_layernorm" in new_key:
            tensor = loaded_shards[0][old_key].clone()
        elif "model.norm.weight" in new_key:
            tensor = loaded_shards[0][old_key]
        else:
            tensor = torch.cat([shard.pop(old_key) for shard in loaded_shards], dim=cat_dim)
        new_dict[new_key] = tensor
    return new_dict
def convert_config(original_config: dict, max_position_embeddings: int = 32768):
    key_mapping = {
        "hidden_size": "dim",
        "num_hidden_layers": "n_layers",
        "intermediate_size": "hidden_dim",
        "num_attention_heads": "n_heads",
        "rms_norm_eps": "norm_eps",
    }
    similar_keys_to_keep = [
        "head_dim",
        "vocab_size",
    ]
    new_config_kwargs = {k: original_config[v] for k, v in key_mapping.items()}
    new_config_kwargs.update({k: v for k, v in original_config.items() if k in similar_keys_to_keep})
    # These are not always defined depending on `params.json`
    new_config_kwargs["sliding_window"] = original_config.get("sliding_window", None)
    new_config_kwargs["num_key_value_heads"] = original_config.get(
        "n_kv_heads", new_config_kwargs["num_attention_heads"]
    )
    new_config_kwargs["rope_theta"] = original_config.get("rope_theta", 10000.0)
    new_config_kwargs["max_position_embeddings"] = original_config.get("max_seq_len", max_position_embeddings)
    # This may sometimes be a string in `params.json`
    if new_config_kwargs["sliding_window"] is not None:
        new_config_kwargs["sliding_window"] = int(new_config_kwargs["sliding_window"])
    new_config = MistralConfig(**new_config_kwargs)
    return new_config
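# For reference (illustrative values only, not taken from any specific checkpoint),
# `convert_config` expects a `params.json` along the lines of:
#
#     {"dim": 5120, "n_layers": 40, "hidden_dim": 14336, "n_heads": 32, "n_kv_heads": 8,
#      "head_dim": 128, "norm_eps": 1e-5, "vocab_size": 131072, "rope_theta": 1000000.0}
#
# Keys such as "sliding_window" or "max_seq_len" are optional and fall back to the defaults above.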
def convert_and_write_model(input_dir: str, output_dir: str, max_position_embeddings: int, modules_are_split: bool):
    """Convert the model and save it (this implicitly saves the config as well)."""
    params = read_json(os.path.join(input_dir, "params.json"))
    config = convert_config(params, max_position_embeddings)
    full_state_dict = {}
    # The model may be split between different files, but a single nn.Module is always fully present in a single file
    if not modules_are_split:
        shards = [file for file in os.listdir(input_dir) if file.endswith(".safetensors")]
        for shard_file in shards:
            original_state_dict = load_file(os.path.join(input_dir, shard_file))
            new_dict = convert_state_dict(original_state_dict, config)
            full_state_dict.update(new_dict)
    # A single nn.Module is split between different checkpoint files
    else:
        shards = [file for file in os.listdir(input_dir) if re.match(r"consolidated.\d+.pth", file)]
        shards = sorted(shards, key=lambda x: int(x.split(".")[1]))
        loaded_shards = [torch.load(os.path.join(input_dir, file), map_location="cpu") for file in shards]
        full_state_dict = convert_state_dict_sharded(loaded_shards, config)
    # Load weights into the model and resave them. Instantiating on the meta device avoids
    # allocating the weights twice; `assign=True` attaches the converted tensors directly.
    with torch.device("meta"):
        model = MistralForCausalLM(config)
    model.load_state_dict(full_state_dict, strict=True, assign=True)
    model.save_pretrained(output_dir)
def convert_and_write_tokenizer(input_dir: str, output_dir: str, tokenizer_template_name: str = ""):
    """Convert the tokenizer and save it."""
    # Tekken format
    if "tekken.json" in os.listdir(input_dir):
        tokenizer_file = os.path.join(input_dir, "tekken.json")
        tokenizer = convert_tekken_tokenizer(tokenizer_file)
    else:
        # May have .v3 or .v7 at the end
        tokenizer_file = [file for file in os.listdir(input_dir) if "tokenizer.model" in file][0]
        tokenizer = LlamaTokenizerFast(os.path.join(input_dir, tokenizer_file))
    # Load a chat template from another model
    if tokenizer_template_name != "":
        template_tok = AutoTokenizer.from_pretrained(tokenizer_template_name)
        tokenizer.chat_template = template_tok.chat_template
    # Finally save it
    tokenizer.save_pretrained(output_dir)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_dir",
        help="Location of Mistral weights, which contains tokenizer.model and model folders",
    )
    parser.add_argument(
        "output_dir",
        help="Location to write HF model and tokenizer",
    )
    parser.add_argument(
        "--template_name",
        type=str,
        default="",
        help="Another model name from which to copy the chat template.",
    )
    parser.add_argument(
        "--max_position_embeddings",
        type=int,
        default=32768,
        help="`max_position_embeddings` field in the config. This needs to be manually passed (not present anywhere otherwise).",
    )
    parser.add_argument(
        "--modules_are_split",
        action="store_true",
        help="If passed, then the weights of a single `nn.Module` are assumed to be split between different files.",
    )
    parser.add_argument(
        "--tokenizer_only",
        action="store_true",
        help="If passed, will only convert the tokenizer.",
    )
    args = parser.parse_args()

    if not args.tokenizer_only:
        convert_and_write_model(args.input_dir, args.output_dir, args.max_position_embeddings, args.modules_are_split)
    convert_and_write_tokenizer(args.input_dir, args.output_dir, args.template_name)


if __name__ == "__main__":
    main()
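# After conversion, the output directory can be loaded like any HF checkpoint, e.g. (a sketch,
# reusing the placeholder output path from the usage example above):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#     model = AutoModelForCausalLM.from_pretrained("/path/to/hf_output")
#     tokenizer = AutoTokenizer.from_pretrained("/path/to/hf_output")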