Test of vmapped matrix multiplication vs. batched matrix multiplication
from equinox import nn
import jax
import jax.numpy as jnp


#########################################
#                Test 1                 #
#########################################

n = 100
a = jax.random.normal(jax.random.PRNGKey(0), (n, 256))
b = jax.random.normal(jax.random.PRNGKey(1), (256, 512))

matmul = jax.jit(jnp.matmul)
batch_matmul = jax.jit(jax.vmap(jnp.matmul, in_axes=(0, None)))

# Warm up / compile both functions before timing
jax.block_until_ready(matmul(jnp.ones_like(a), jnp.ones_like(b)))
jax.block_until_ready(batch_matmul(jnp.ones_like(a), jnp.ones_like(b)))

# Time each function
%timeit jax.block_until_ready(matmul(a, b))
%timeit jax.block_until_ready(batch_matmul(a, b))

# Output:
# 180 µs ± 44.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 214 µs ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
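
# Sanity check (an added sketch, not part of the original gist): for a 2-D `a`,
# vmapping jnp.matmul over rows computes the same thing as a plain matmul, so
# the two timings above measure equivalent work. A loose atol is assumed to
# absorb float32 accumulation differences.
assert jnp.allclose(matmul(a, b), batch_matmul(a, b), atol=1e-4)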

#########################################
#                Test 2                 #
#########################################

n = 100
a = jax.random.normal(jax.random.PRNGKey(2), (n, 256))
b = jax.random.normal(jax.random.PRNGKey(3), (n, 256, 512))

batch_matmul = jax.jit(jax.lax.batch_matmul)
vmap_batch_matmul = jax.jit(jax.vmap(jnp.matmul))

# Warm up / compile both functions before timing
# (batch_matmul needs a 3-D lhs, so a length-1 axis is inserted into a)
jax.block_until_ready(batch_matmul(jnp.ones_like(a)[:, None, :], jnp.ones_like(b)))
jax.block_until_ready(vmap_batch_matmul(jnp.ones_like(a), jnp.ones_like(b)))

# Time each function
%timeit jax.block_until_ready(batch_matmul(a[:, None, :], b))
%timeit jax.block_until_ready(vmap_batch_matmul(a, b))

# Output:
# 366 µs ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 1.25 ms ± 44.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
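
# Note (an added sketch): jax.lax.batch_matmul expects both operands to share a
# leading batch dimension, hence the `a[:, None, :]` reshape; it returns shape
# (n, 1, 512), while the vmapped version returns (n, 512). Up to that extra
# axis, the two should agree:
assert jnp.allclose(
    batch_matmul(a[:, None, :], b).squeeze(1),
    vmap_batch_matmul(a, b),
    atol=1e-4,
)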

#########################################
#                Test 3                 #
#########################################

n = 100
a = jax.random.normal(jax.random.PRNGKey(4), (n, 256))
linear = nn.Linear(256, 512, key=jax.random.PRNGKey(5))

def batch_linear(a, weight, bias):
    # equinox stores weight as (out_features, in_features), hence the transpose
    return a @ weight.T + bias

linear_fn = jax.jit(batch_linear)
vmap_linear_fn = jax.jit(jax.vmap(linear.__call__))

# Warm up / compile both functions before timing
jax.block_until_ready(linear_fn(jnp.ones_like(a), jnp.ones_like(linear.weight), jnp.ones_like(linear.bias)))
jax.block_until_ready(vmap_linear_fn(jnp.ones_like(a)))

# Time each function
%timeit jax.block_until_ready(linear_fn(a, linear.weight, linear.bias))
%timeit jax.block_until_ready(vmap_linear_fn(a))

# Output:
# 254 µs ± 7.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 251 µs ± 3.77 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
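
# Sanity check (an added sketch): both paths should produce the same (n, 512)
# output, confirming the hand-written batch_linear matches the vmapped
# per-example equinox Linear call.
assert jnp.allclose(
    linear_fn(a, linear.weight, linear.bias),
    vmap_linear_fn(a),
    atol=1e-4,
)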