@awni
Last active July 24, 2025 13:18
Writing Fast MLX

Making MLX Go Fast

This guide assumes you have some familiarity with MLX and want to make your MLX model or algorithm as efficient as possible.

The guide covers the following:

  • Graph Evaluation
  • Type Promotion
  • Operations
  • Compile
  • Memory Use
  • Profiling

Graph Evaluation

Recall that MLX is lazy. When you call an MLX op, no computation actually happens; you are simply building a graph. The computation happens when you explicitly or implicitly evaluate an array. Read more about how this works in the documentation.

Evaluating the graph incurs some overhead, so don't do it too frequently. Conversely, you don't want the graph to grow too large before evaluating it, as this can also be expensive. Most numerical and machine learning algorithms are iterative, and a good place to evaluate the graph is at the end of each iteration. Some examples:

  • After an iteration of gradient descent
  • After producing one token with a language model
  • After taking one denoising step in a diffusion model

Overly frequent evaluations sometimes happen by accident. For example:

# output is an mx.array
for x in output:
  do_something(x.item())

The same thing can be written more explicitly with operations and mx.eval as:

for i in range(len(output)):    
  x = output[i]
  mx.eval(x)
  do_something(x.item())

Two better options are:

  1. When possible, avoid calling item() and do everything in MLX.
  2. Move the entire output to Python or NumPy first.

An example of the second approach:

for x in output.tolist():
  do_something(x)

Asynchronous Evaluation

For a latency-sensitive computation that runs many times, mx.async_eval can be useful. Normally mx.eval is synchronous: it returns only when the computation is complete. mx.async_eval instead schedules the evaluation and returns to the main thread immediately. You can use this to pipeline graph construction with computation like so:

def generator():
    out = my_function()
    # mx.async_eval returns None; it schedules evaluation of its arguments.
    mx.async_eval(out)

    while True:
        out_next = my_function()
        mx.async_eval(out_next)
        mx.eval(out)
        yield out
        out = out_next

For this to work, my_function() cannot do any synchronous evaluations (e.g. calling mx.eval, converting to NumPy, etc.). Furthermore, any synchronous work done on out on the same stream can stall the pipeline:

for out in generator():
    out = out * 2
    # Stalls the pipeline!
    mx.eval(out)

An easy fix for this is to put the pipeline in a separate stream:

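def generator():
    stream = mx.new_stream(mx.gpu)
    # Build each graph on the pipeline stream, but keep the yield outside
    # the `with` so the caller's work runs on the default stream.
    with mx.stream(stream):
        out = my_function()
    mx.async_eval(out)

    while True:
        with mx.stream(stream):
            out_next = my_function()
        mx.async_eval(out_next)
        mx.eval(out)
        yield out
        out = out_next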

Type Promotion

One of the most common performance issues comes from accidental up-casting. Make sure you understand how type promotion works in MLX. The inputs to an MLX operation are typically promoted to a common type which doesn't lose precision. For example:

x = mx.array(1.0, mx.float32) * mx.array(2.0, mx.float16)

will result in x with type mx.float32. Similarly:

x = mx.array(1.0, mx.bfloat16) * mx.array(2.0, mx.float16)

will result in x with type mx.float32. A common mistake is to multiply a half-precision array by a default-typed scalar array, which up-casts everything to mx.float32:

# Warning: x has type mx.float32
x = my_fp16_array * mx.array(2.0)

To multiply by a scalar while preserving the input type, use Python scalars. Python scalars are weakly typed and have more relaxed promotion rules when used with MLX arrays.

# Ok, x has type mx.float16
x = my_fp16_array * 2.0

Operations

Use Fast Ops

Use mx.fast ops when possible:

  • mx.fast.rms_norm
  • mx.fast.layer_norm
  • mx.fast.rope
  • mx.fast.scaled_dot_product_attention

Many of these operations take optional parameters, so that one op covers several variants of the function. For example, the weight and bias parameters of mx.fast.layer_norm are optional, so it can be used with or without either of them.

Precision

For operations that typically need higher precision, there is usually no need to upcast explicitly. For example, mx.fast.rms_norm and mx.fast.layer_norm accumulate in higher precision, so it's wasteful to upcast into and downcast out of these operations:

# No need for this!
mx.fast.rms_norm(x.astype(mx.float32), w, eps).astype(x.dtype)

# This is just as good:
mx.fast.rms_norm(x, w, eps)

Similarly, for mx.softmax, pass precise=True to compute the softmax in higher precision rather than explicitly casting the input and output.

Misc

  • For vector-matrix multiplication, x @ W.T is faster than x @ W; for matrix-vector multiplication, W @ x is faster than W.T @ x.
  • Use mx.addmm for a @ b + c (e.g. a linear layer with a bias).
  • Where it makes sense, use mx.take_along_axis and mx.put_along_axis instead of fancy indexing.
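  • Use broadcasting instead of concatenation. For example, prefer mx.repeat(a, n) or mx.tile(a, n) over mx.concatenate([a] * n). Note that only mx.tile preserves the concatenation's element order; mx.repeat repeats each element in place.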

Compile

Compiling graphs with mx.compile can make them run a lot faster, but there are some sharp edges worth knowing about.

First, be aware of when a function will be recompiled. Recompilation is relatively expensive and should only be done if there is sufficient work over which to amortize the cost.

The default behavior of mx.compile is to do a shape-dependent compilation. This means the function will be recompiled if the shape of any input changes.

MLX supports a shapeless compilation by passing shapeless=True to mx.compile. It's easy to make hard-to-detect mistakes with shapeless compilation. Make sure to read and understand the documentation and use it with care.

A function will also be recompiled if any constant inputs change:

@mx.compile
def fun(x, scale):
  return scale * x

fun(x, 3)

# Recompiles!
fun(x, 4)

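In this case, a simple fix is to make scale an mx.array: changing an array input's value does not trigger recompilation, only changing its shape or dtype does.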

Compiling Closures

Be careful when compiling a closure that captures an mx.array:

y = some_function()

@mx.compile
def fun(x):
  return x + y

Since y is not an input to fun, the compiled graph will include the entire computation that produces y. Usually you want to compute y once and reuse it in the compiled function. Either pass it explicitly as an input to fun, or pass it as an implicit input to mx.compile like so:

from functools import partial

y = some_function()

@partial(mx.compile, inputs=[y])
def fun(x):
  return x + y

Memory Use

Lazy Loading

Loading arrays from a file is lazy in MLX:

weights = mx.load("model.safetensors")

The above function returns instantly, regardless of the file size. To actually load the weights into memory, you can do mx.eval(weights).

Suppose the weights are stored on disk in 32-bit precision (i.e. mx.float32), but your model only needs 16-bit precision:

weights = mx.load("model.safetensors")
mx.eval(weights)
weights = {k: v.astype(mx.float16) for k, v in weights.items()}

In the above, the weights will be loaded into memory in full precision and then cast to 16-bit. This requires memory for all the weights in 32-bit plus memory for the weights in 16-bit.

This is much better:

weights = mx.load("model.safetensors")
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
mx.eval(weights)

Evaluating after the cast to mx.float16 cuts peak memory to roughly a third: about 2 bytes per parameter (plus one 32-bit weight at a time) instead of about 6 (4 for the 32-bit copy plus 2 for the 16-bit copy). That's because the weights are never all materialized in 32-bit. Right after each weight is loaded in 32-bit precision, it is cast to 16-bit, and the memory for the 32-bit weight can be reused when loading the next weight.

Note that MLX can only lazy load from a file when mx.load is given a string path. Due to lifetime management issues, lazy loading from file handles is not supported, so avoid this:

weights = mx.load(open("model.safetensors", 'rb'))

Release Temporaries

One way to reduce memory consumption is to avoid holding temporaries you don't need. This is a typical training loop:

for x, y in dataset:
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model, optimizer.state)

This is suboptimal because a reference to grads is held during the call to mx.eval, which prevents that memory from being reused for any other part of the computation.

This is better:

def step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss
    
for x, y in dataset:
    loss = step(x, y)
    mx.eval(model, optimizer.state)

In this case the reference to grads is released before mx.eval and the memory can be reused. You can achieve the same goal using del as long as it's before the call to mx.eval:

for x, y in dataset:
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    del grads
    mx.eval(model, optimizer.state)

Misc

  • MLX will cache memory buffers of recently released arrays rather than returning them to the system. In some cases, especially for variable shape computations, the cache can get large. To help with this, MLX has some functions for logging and customizing the behavior of memory allocation.

Profiling

A good first step is to check GPU utilization using, for example, mactop. If it's not pegged at close to 100%, there is likely a non-MLX bottleneck somewhere in the program. A common culprit is data loading or preprocessing.

If GPU utilization is good, the next step is to find which operations take the most time. One way to do this is with the Metal debugger; see the documentation on profiling MLX with the Metal debugger.
