Skip to content

Instantly share code, notes, and snippets.

@awni
Last active January 25, 2025 21:30

Revisions

  1. awni revised this gist Aug 23, 2024. 1 changed file with 12 additions and 23 deletions.
    35 changes: 12 additions & 23 deletions l3min.py
    Original file line number Diff line number Diff line change
    @@ -46,31 +46,23 @@ def __init__(
    freqs = base ** (mx.arange(0, self.dims, 2) / self.dims)
    wavelens = 2 * mx.pi * freqs

    smooths = (wavelens - high_freq_wavelen) / (
    low_freq_wavelen - high_freq_wavelen
    freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
    is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
    smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
    high_freq_factor - low_freq_factor
    )
    new_base_freqs = freqs * (1 - smooths) * factor + smooths
    new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
    new_base_freqs = mx.where(
    wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
    )
    self.base = new_base_freqs.mean().item()
    smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
    self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)

    def __call__(self, x, offset=0):
    seq_len = x.shape[1] + offset
    base = self.base
    if seq_len > self.max_position_embeddings:
    base *= (seq_len / self.max_position_embeddings) ** (
    self.dims / (self.dims - 2)
    )

    return mx.fast.rope(
    x,
    self.dims,
    traditional=False,
    base=base,
    base=None,
    scale=1.0,
    offset=offset,
    freqs=self._freqs,
    )


    @@ -82,7 +74,7 @@ def __init__(self, args):
    self.n_heads = n_heads = args.num_attention_heads
    self.n_kv_heads = n_kv_heads = args.num_key_value_heads

    self.head_dim = head_dim = args.hidden_size // n_heads
    head_dim = args.hidden_size // n_heads
    self.scale = head_dim ** (-0.5)

    self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
    @@ -124,10 +116,8 @@ def __call__(self, x, mask=None, cache=None):


    class MLP(nn.Module):
    def __init__(self, args):
    def __init__(self, dim, hidden_dim):
    super().__init__()
    dim = args.hidden_size
    hidden_dim = args.intermediate_size
    self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
    self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
    self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
    @@ -140,7 +130,7 @@ class TransformerBlock(nn.Module):
    def __init__(self, args):
    super().__init__()
    self.self_attn = Attention(args)
    self.mlp = MLP(args)
    self.mlp = MLP(args.hidden_size, args.intermediate_size)
    self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
    self.post_attention_layernorm = nn.RMSNorm(
    args.hidden_size, eps=args.rms_norm_eps
    @@ -149,8 +139,7 @@ def __init__(self, args):
    def __call__(self, x, mask=None, cache=None):
    r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
    h = x + r
    r = self.mlp(self.post_attention_layernorm(h))
    out = h + r
    out = h + self.mlp(self.post_attention_layernorm(h))
    return out, cache


  2. awni revised this gist Aug 23, 2024. No changes.
  3. awni revised this gist Aug 9, 2024. 1 changed file with 2 additions and 5 deletions.
    7 changes: 2 additions & 5 deletions l3min.py
    Original file line number Diff line number Diff line change
    @@ -245,10 +245,8 @@ def generate(
    ):
    print("[INFO] Loading model from disk.")
    model, tokenizer = load(model)
    messages = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,
    return_tensors="mlx",
    )
    @@ -260,8 +258,7 @@ def generate(
    for token, n in zip(generate_step(prompt, model), range(max_tokens)):
    tokens.append(token)
    if n == 0:
    toc = time.time()
    prompt_tps = prompt.size / (toc - tic)
    prompt_tps = prompt.size / (time.time() - tic)
    tic = time.time()

    if token == tokenizer.eos_token_id:
  4. awni revised this gist Aug 9, 2024. 1 changed file with 62 additions and 15 deletions.
    77 changes: 62 additions & 15 deletions l3min.py
    Original file line number Diff line number Diff line change
    @@ -2,14 +2,14 @@
    A minimal, fast example generating text with Llama 3.1 in MLX.
    To run, install the requirements:
    pip install -U mlx transformers fire
    Then generate text with:
    python l3min.py "How tall is K2?"
    """

    import fire
    import json
    import glob
    @@ -20,7 +20,7 @@
    import time
    from transformers import AutoTokenizer
    from types import SimpleNamespace


    class DynamicNTKScalingRoPE(nn.Module):

    @@ -34,18 +34,18 @@ def __init__(
    super().__init__()
    self.dims = dims
    self.max_position_embeddings = max_position_embeddings

    factor = rope_scaling["factor"]
    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    old_context_len = rope_scaling["original_max_position_embeddings"]

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    freqs = base ** (mx.arange(0, self.dims, 2) / self.dims)
    wavelens = 2 * mx.pi * freqs

    smooths = (wavelens - high_freq_wavelen) / (
    low_freq_wavelen - high_freq_wavelen
    )
    @@ -55,15 +55,15 @@ def __init__(
    wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
    )
    self.base = new_base_freqs.mean().item()

    def __call__(self, x, offset=0):
    seq_len = x.shape[1] + offset
    base = self.base
    if seq_len > self.max_position_embeddings:
    base *= (seq_len / self.max_position_embeddings) ** (
    self.dims / (self.dims - 2)
    )
    )

    return mx.fast.rope(
    x,
    self.dims,
    @@ -134,8 +134,8 @@ def __init__(self, args):

    def __call__(self, x):
    return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


    class TransformerBlock(nn.Module):
    def __init__(self, args):
    super().__init__()
    @@ -153,7 +153,7 @@ def __call__(self, x, mask=None, cache=None):
    out = h + r
    return out, cache


    class LlamaModel(nn.Module):
    def __init__(self, args):
    super().__init__()
    @@ -188,7 +188,54 @@ def __init__(self, args):

    def __call__(self, inputs, cache=None):
    out, cache = self.model(inputs, cache)
    return self.lm_head(out), cache
    return self.lm_head(out), cache


    def load(hf_repo):
        """Fetch a model from the Hugging Face Hub and load it into MLX.

        Downloads config + safetensors weights, builds the Model, applies
        quantization when the config requests it, and returns a
        ``(model, tokenizer)`` pair ready for generation.
        """
        model_path = Path(
            snapshot_download(
                repo_id=hf_repo,
                # Only the JSON config/tokenizer files and weights are needed.
                allow_patterns=["*.json", "*.safetensors"],
            )
        )
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)

        # Merge all weight shards into one flat name -> array mapping.
        weight_files = glob.glob(str(model_path / "model*.safetensors"))
        weights = {}
        for wf in weight_files:
            weights.update(mx.load(wf))

        model = Model(SimpleNamespace(**config))

        # Quantize layers *before* loading so the quantized weight names match.
        if (quantization := config.get("quantization", None)) is not None:
            nn.quantize(model, **quantization)

        model.load_weights(list(weights.items()))

        # Force evaluation so all parameters are materialized up front.
        mx.eval(model.parameters())

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        # NOTE(review): presumably warms up the tokenizer's decoder so the
        # first real decode isn't slow — confirm.
        tokenizer.decode([0])

        return model, tokenizer


    def generate_step(prompt, model):
        """Yield token ids one at a time, greedily decoding from ``model``.

        Maintains the KV cache across steps and uses ``mx.async_eval`` to
        start computing the next token before the current one is consumed.
        """
        cache = None

        def _step(y):
            # One forward pass; updates the shared KV cache in the closure.
            nonlocal cache
            logits, cache = model(y, cache=cache)
            # Greedy decoding: argmax over the final position's logits.
            return mx.argmax(logits[:, -1, :], axis=-1)

        y = _step(prompt)
        mx.async_eval(y)
        while True:
            # Launch the next step asynchronously, then yield the current
            # token — overlaps compute with whatever the caller does.
            next_y = _step(y[None])
            mx.async_eval(next_y)
            yield y.item()
            y = next_y


    def generate(
    @@ -236,4 +283,4 @@ def generate(


    if __name__ == "__main__":
    fire.Fire(generate)
    fire.Fire(generate)
  5. awni revised this gist Aug 9, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion l3min.py
    Original file line number Diff line number Diff line change
    @@ -7,7 +7,7 @@
    Then generate text with:
    python l3.py "How tall is K2?"
    python l3min.py "How tall is K2?"
    """

    import fire
  6. awni created this gist Aug 9, 2024.
    239 changes: 239 additions & 0 deletions l3min.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,239 @@
    """
    A minimal, fast example generating text with Llama 3.1 in MLX.
    To run, install the requirements:
    pip install -U mlx transformers fire
    Then generate text with:
    python l3.py "How tall is K2?"
    """

    import fire
    import json
    import glob
    from huggingface_hub import snapshot_download
    import mlx.core as mx
    import mlx.nn as nn
    from pathlib import Path
    import time
    from transformers import AutoTokenizer
    from types import SimpleNamespace


    class DynamicNTKScalingRoPE(nn.Module):
        """Rotary positional embedding with Llama 3.1 frequency scaling.

        Llama 3.1's ``rope_scaling`` slows down the *long-wavelength*
        (low-frequency) rotary dimensions by ``factor``, leaves the
        short-wavelength dimensions untouched, and smoothly interpolates
        the band in between. This is inherently per-dimension: averaging
        the scaled frequencies into a single scalar ``base`` (as a naive
        implementation might) does not reproduce the scheme. We therefore
        keep the full per-dimension frequency array and hand it directly
        to ``mx.fast.rope`` via its ``freqs`` argument.
        """

        def __init__(
            self,
            dims,
            rope_scaling,
            max_position_embeddings=2048,
            base=10000,
        ):
            super().__init__()
            self.dims = dims
            self.max_position_embeddings = max_position_embeddings

            factor = rope_scaling["factor"]
            low_freq_factor = rope_scaling["low_freq_factor"]
            high_freq_factor = rope_scaling["high_freq_factor"]
            old_context_len = rope_scaling["original_max_position_embeddings"]

            low_freq_wavelen = old_context_len / low_freq_factor
            high_freq_wavelen = old_context_len / high_freq_factor

            freqs = base ** (mx.arange(0, self.dims, 2) / self.dims)
            wavelens = 2 * mx.pi * freqs

            # Low-frequency dims (wavelength longer than the old context's
            # low-frequency cutoff) are slowed down by ``factor``.
            freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
            # Medium band: blend smoothly between scaled and unscaled.
            is_medium_freq = (wavelens > high_freq_wavelen) & (
                wavelens < low_freq_wavelen
            )
            smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
            self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)

        def __call__(self, x, offset=0):
            # ``base`` must be None when explicit per-dimension ``freqs``
            # are supplied; the scaling is already baked into self._freqs.
            return mx.fast.rope(
                x,
                self.dims,
                traditional=False,
                base=None,
                scale=1.0,
                offset=offset,
                freqs=self._freqs,
            )


    class Attention(nn.Module):
        """Multi-head attention with grouped KV heads, RoPE, and a KV cache."""

        def __init__(self, args):
            super().__init__()

            dim = args.hidden_size
            self.n_heads = n_heads = args.num_attention_heads
            self.n_kv_heads = n_kv_heads = args.num_key_value_heads

            self.head_dim = head_dim = args.hidden_size // n_heads
            self.scale = head_dim ** (-0.5)

            self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
            self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
            self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
            self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

            self.rope = DynamicNTKScalingRoPE(
                dims=head_dim,
                rope_scaling=args.rope_scaling,
                max_position_embeddings=args.max_position_embeddings,
                base=args.rope_theta,
            )

        def __call__(self, x, mask=None, cache=None):
            batch, seq, _ = x.shape

            # Project, then reshape to (batch, heads, seq, head_dim).
            q = self.q_proj(x)
            k = self.k_proj(x)
            v = self.v_proj(x)
            q = q.reshape(batch, seq, self.n_heads, -1).transpose(0, 2, 1, 3)
            k = k.reshape(batch, seq, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
            v = v.reshape(batch, seq, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

            if cache is None:
                q = self.rope(q)
                k = self.rope(k)
            else:
                # Rotate at the cached offset, then append to the cache.
                past_k, past_v = cache
                offset = past_k.shape[2]
                q = self.rope(q, offset=offset)
                k = self.rope(k, offset=offset)
                k = mx.concatenate([past_k, k], axis=2)
                v = mx.concatenate([past_v, v], axis=2)

            attn = mx.fast.scaled_dot_product_attention(
                q, k, v, mask=mask, scale=self.scale
            )
            attn = attn.transpose(0, 2, 1, 3).reshape(batch, seq, -1)
            return self.o_proj(attn), (k, v)


    class MLP(nn.Module):
        """SwiGLU feed-forward block: down(silu(gate(x)) * up(x))."""

        def __init__(self, args):
            super().__init__()
            dim, hidden_dim = args.hidden_size, args.intermediate_size
            self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
            self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
            self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

        def __call__(self, x):
            gated = nn.silu(self.gate_proj(x)) * self.up_proj(x)
            return self.down_proj(gated)


    class TransformerBlock(nn.Module):
        """Pre-norm transformer layer: attention then MLP, each with a residual."""

        def __init__(self, args):
            super().__init__()
            self.self_attn = Attention(args)
            self.mlp = MLP(args)
            self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
            self.post_attention_layernorm = nn.RMSNorm(
                args.hidden_size, eps=args.rms_norm_eps
            )

        def __call__(self, x, mask=None, cache=None):
            attn_out, cache = self.self_attn(self.input_layernorm(x), mask, cache)
            h = x + attn_out
            out = h + self.mlp(self.post_attention_layernorm(h))
            return out, cache


    class LlamaModel(nn.Module):
        """Token embedding, a stack of transformer blocks, and a final norm."""

        def __init__(self, args):
            super().__init__()
            self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
            self.layers = [
                TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
            ]
            self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

        def __call__(self, inputs, cache=None):
            h = self.embed_tokens(inputs)

            # A causal mask is only needed when processing multiple tokens.
            mask = None
            if h.shape[1] > 1:
                mask = nn.MultiHeadAttention.create_additive_causal_mask(
                    h.shape[1]
                ).astype(h.dtype)

            if cache is None:
                cache = [None] * len(self.layers)

            for i, layer in enumerate(self.layers):
                h, cache[i] = layer(h, mask, cache[i])

            return self.norm(h), cache


    class Model(nn.Module):
        """LlamaModel plus the output projection to vocabulary logits."""

        def __init__(self, args):
            super().__init__()
            self.model = LlamaModel(args)
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

        def __call__(self, inputs, cache=None):
            hidden, cache = self.model(inputs, cache)
            return self.lm_head(hidden), cache


    def generate(
        prompt,
        model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
        max_tokens=128,
    ):
        """Generate up to ``max_tokens`` tokens for ``prompt`` and stream them.

        Loads the model/tokenizer, wraps the prompt in the chat template,
        streams decoded text to stdout, and prints prompt/generation
        tokens-per-second at the end.
        """
        print("[INFO] Loading model from disk.")
        model, tokenizer = load(model)
        messages = [{"role": "user", "content": prompt}]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="mlx",
        )

        print("[INFO] Starting generation...")
        tic = time.time()
        # ``s`` tracks how much of the decoded string has already been printed;
        # ``tokens`` holds the ids of the current (not-yet-flushed) line.
        s = 0
        tokens = []
        for token, n in zip(generate_step(prompt, model), range(max_tokens)):
            tokens.append(token)
            if n == 0:
                # First token marks the end of prompt processing; restart the
                # clock so generation speed is measured separately.
                toc = time.time()
                prompt_tps = prompt.size / (toc - tic)
                tic = time.time()

            if token == tokenizer.eos_token_id:
                break

            # Re-decode the accumulated tokens and print only the new suffix.
            words = tokenizer.decode(tokens)
            print(words[s:], end="", flush=True)
            if words[-1] == "\n":
                # Line complete: reset the buffer so decoding stays cheap.
                tokens = []
                s = 0
            else:
                s = len(words)

        # Flush whatever remains after the loop (eos break or token budget).
        print(tokenizer.decode(tokens)[s:], flush=True)
        # NOTE(review): if max_tokens == 0 the loop never runs and ``n`` /
        # ``prompt_tps`` are undefined here (NameError) — confirm intended.
        gen_tps = (n + 1) / (time.time() - tic)
        print("=" * 10)
        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
        print(f"Generation: {gen_tps:.3f} tokens-per-sec")


    # CLI entry point: fire maps command-line arguments onto generate()'s
    # parameters (e.g. `python l3min.py "How tall is K2?"`).
    if __name__ == "__main__":
        fire.Fire(generate)