Last active
February 1, 2025 03:41
-
-
Save apivovarov/9a3b424f19a7522ac348372a8f5f6603 to your computer and use it in GitHub Desktop.
jax softmax accuracy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ======== softmax ========================================
# Compare the accuracy of jax.nn.softmax at reduced precisions
# (float16 / bfloat16) against a float64 NumPy reference.
import jax
import jax.numpy as jnp
import numpy as np

print("JAX devices:", jax.devices())
def softmax(x):
    """Numerically stable softmax over the last axis (NumPy reference).

    Subtracting the per-row maximum before exponentiating prevents
    overflow in np.exp; softmax is shift-invariant, so the result is
    unchanged.
    """
    x_max = np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(x - x_max)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
# Reference data: float64 input drawn uniformly from [0, 10) and the
# float64 softmax computed with the NumPy reference implementation.
shape = (8, 128)
a_np = np.random.uniform(low=0.0, high=10.0, size=shape).astype(np.float64)
c_np = softmax(a_np)
@jax.jit
def myfunc(a):
    """JIT-compiled softmax over the last axis, in the input's own dtype."""
    return jax.nn.softmax(a, axis=-1)
# Upcast operand to float32 before the softmax, then cast the result
# down to bfloat16.
# NOTE(review): this variant is defined but never called in the sweep
# below — presumably intended for a second comparison pass; confirm.
@jax.jit
def myfunc2(a):
    """Softmax computed in float32, returned as bfloat16.

    The intermediate float32 computation avoids the accuracy loss of
    running the reduction directly in a 16-bit dtype.
    """
    a2 = a.astype(jnp.float32)
    r2 = jax.nn.softmax(a2, axis=-1)
    return r2.astype(jnp.bfloat16)
# Sweep input dtypes and report the error of JAX softmax against the
# float64 NumPy reference c_np.
for dtype in [jnp.float32, jnp.float16, jnp.bfloat16]:
    a = jnp.array(a_np, dtype=dtype)
    y_np = np.asarray(myfunc(a))

    abs_diff = np.abs(c_np - y_np)
    mean_abs_diff = np.mean(abs_diff)

    # Relative difference; epsilon guards against division by zero.
    epsilon = 1e-12
    rel_diff = abs_diff / (np.abs(c_np) + epsilon)
    mean_rel_diff = np.mean(rel_diff)

    # Output results
    print("dtype:", y_np.dtype)
    print("Mean Absolute Difference: %.3e" % mean_abs_diff)
    print("Mean Relative Difference: %.3e" % mean_rel_diff)
Author
apivovarov
commented
Feb 1, 2025
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment