Skip to content

Instantly share code, notes, and snippets.

@apivovarov
Last active February 1, 2025 03:41
Show Gist options
  • Save apivovarov/9a3b424f19a7522ac348372a8f5f6603 to your computer and use it in GitHub Desktop.
Save apivovarov/9a3b424f19a7522ac348372a8f5f6603 to your computer and use it in GitHub Desktop.
jax softmax accuracy
# ======== softmax ========================================
import jax
import jax.numpy as jnp
import numpy as np
# Show which JAX backend devices are visible so the printed results can be
# attributed to the hardware they were produced on.
print("JAX devices:", jax.devices())
def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``, computed with NumPy.

    The maximum along ``axis`` is subtracted before exponentiating. This
    leaves the mathematical result unchanged but keeps ``exp`` from
    overflowing on large inputs.

    Args:
        x: Array-like of numbers.
        axis: Axis along which to normalise. Defaults to the last axis.

    Returns:
        A NumPy array shaped like ``x`` whose entries sum to 1 along ``axis``.
    """
    x = np.asarray(x)
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
# Reference data: a float64 input drawn uniformly from [0, 10) and the NumPy
# softmax of it, which the JAX results are compared against.
shape = (8, 128)
a_np = np.asarray(np.random.uniform(0.0, 10.0, shape), dtype=np.float64)
c_np = softmax(a_np)
@jax.jit
def myfunc(a):
    """Return the JIT-compiled ``jax.nn.softmax`` of ``a`` over its last axis.

    The input is passed to ``jax.nn.softmax`` as-is, with no dtype conversion
    before or after the call.
    """
    return jax.nn.softmax(a, axis=-1)
# Variant that upcasts the operand to float32 before taking the softmax.
@jax.jit
def myfunc2(a):
    """Return the softmax of ``a`` over its last axis, evaluated in float32.

    The operand is converted to float32, passed through ``jax.nn.softmax``,
    and the result is converted to bfloat16.

    Note:
        The result is always bfloat16, regardless of the dtype of ``a``.
    """
    a_f32 = a.astype(jnp.float32)
    probs_f32 = jax.nn.softmax(a_f32, axis=-1)
    return probs_f32.astype(jnp.bfloat16)
# Compare the JAX softmax against the float64 NumPy reference for each dtype.
# epsilon keeps the relative difference from dividing by zero.
epsilon = 1e-12
for dtype in [jnp.float32, jnp.float16, jnp.bfloat16]:
    a = jnp.array(a_np, dtype=dtype)
    y_np = np.asarray(myfunc(a))

    # Upcast explicitly so the differences are computed in float64.
    abs_diff = np.abs(c_np - y_np.astype(np.float64))
    mean_abs_diff = np.mean(abs_diff)

    # Relative difference, measured against the reference values.
    rel_diff = abs_diff / (np.abs(c_np) + epsilon)
    mean_rel_diff = np.mean(rel_diff)

    # Output results
    print("dtype:", y_np.dtype)
    print("Mean Absolute Difference: %.3e" % mean_abs_diff)
    print("Mean Relative Difference: %.3e" % mean_rel_diff)
@apivovarov
Copy link
Author

JAX devices: [CpuDevice(id=0)]
dtype: float32
Mean Absolute Difference: 1.661e-09
Mean Relative Difference: 1.777e-07
dtype: float16
Mean Absolute Difference: 1.417e-05
Mean Relative Difference: 1.576e-03
dtype: bfloat16
Mean Absolute Difference: 1.222e-04
Mean Relative Difference: 1.164e-02

JAX devices: [CudaDevice(id=0)] T4 Turing Compute Capability: 7.5
dtype: float32
Mean Absolute Difference: 1.559e-09
Mean Relative Difference: 1.779e-07
dtype: float16
Mean Absolute Difference: 1.315e-05
Mean Relative Difference: 1.520e-03
dtype: bfloat16
Mean Absolute Difference: 1.173e-04
Mean Relative Difference: 7.832e-03

JAX devices: [CudaDevice(id=0)] A100 Ampere Compute Capability: 8.0
dtype: float32
Mean Absolute Difference: 1.879e-09
Mean Relative Difference: 1.865e-07
dtype: float16
Mean Absolute Difference: 1.442e-05
Mean Relative Difference: 1.530e-03
dtype: bfloat16
Mean Absolute Difference: 1.147e-04
Mean Relative Difference: 7.862e-03

JAX devices: [TpuDevice(id=0)]
dtype: float32
Mean Absolute Difference: 8.137e-09
Mean Relative Difference: 1.082e-06
dtype: float16
Mean Absolute Difference: 1.230e-05
Mean Relative Difference: 1.228e-03
dtype: bfloat16
Mean Absolute Difference: 1.069e-04
Mean Relative Difference: 7.876e-03

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment