Created
November 29, 2022 21:33
-
-
Save zhangqiaorjc/22cb60d3e12edd0b81143bd42442221e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from absl.testing import absltest | |
from absl import logging | |
import jax | |
import jax.numpy as jnp | |
def amax(x):
    """Return the largest absolute value found in ``x``.

    Args:
        x: Anything ``jnp.abs`` accepts (array or scalar).

    Returns:
        The maximum of ``|x|`` taken over all elements, as a JAX scalar array.
    """
    magnitudes = jnp.abs(x)
    return jnp.max(magnitudes)
def get_scale(amax_v):
    """Derive the quantization scale from an absolute-maximum value.

    Args:
        amax_v: Absolute maximum of the tensor that is going to be scaled.

    Returns:
        ``1.1 * (amax_v / 448)``.
    """
    # NOTE(review): 448 is presumably the largest finite float8 (E4M3) value
    # and 1.1 a headroom margin -- confirm against the RFC.
    normalized = amax_v / 448
    return 1.1 * normalized
# Stand-in dtypes used by quantize()/dequantize() in this file: float16 plays
# the role of the 8-bit type and float32 the role of the wider type.
fake_fp8 = jnp.float16
fake_bf16 = jnp.float32
def quantize(x, amax_v):
    """Scale ``x`` down by the amax-derived factor and cast it to ``fake_fp8``.

    Args:
        x: Array or Python scalar to quantize.
        amax_v: Absolute maximum used to derive the scale via ``get_scale``.
            A value of zero produces a zero scale and therefore a division
            by zero.

    Returns:
        An array of dtype ``fake_fp8`` holding ``x / get_scale(amax_v)``.
    """
    # Wrap x so the quotient is always a JAX array: when both x and amax_v are
    # plain Python numbers the division would yield a float, which has no
    # .astype method.
    scaled = jnp.asarray(x) / get_scale(amax_v)
    return scaled.astype(fake_fp8)
def dequantize(x, amax_v):
    """Cast ``x`` to ``fake_bf16`` and multiply the scale back in.

    Args:
        x: Array exposing ``.astype`` (for example the output of ``quantize``).
        amax_v: Absolute maximum used to derive the scale via ``get_scale``.

    Returns:
        ``x`` converted to ``fake_bf16`` and multiplied by
        ``get_scale(amax_v)``.
    """
    widened = x.astype(fake_bf16)
    return widened * get_scale(amax_v)
# RFC approach (https://github.com/openxla/xla/discussions/22)
def matmul_f8_rfc_fp8_inp(x_fp8, y_fp8, x_amax, y_amax, z_amax):
    """Multiply two quantized operands and return a quantized product.

    Both inputs are dequantized with their own amax, multiplied with
    ``jnp.dot``, and the product is quantized again using ``z_amax``.

    Args:
        x_fp8: Quantized left operand.
        y_fp8: Quantized right operand.
        x_amax: Amax used to dequantize ``x_fp8``.
        y_amax: Amax used to dequantize ``y_fp8``.
        z_amax: Amax used to quantize the product.

    Returns:
        A ``(z_fp8, new_z_amax)`` tuple: the product quantized with
        ``z_amax``, and the amax measured on the unquantized product.
    """
    product = jnp.dot(dequantize(x_fp8, x_amax), dequantize(y_fp8, y_amax))
    # Measured before re-quantizing; it is returned to the caller and does
    # not influence the scale applied to this product.
    observed_amax = amax(product)
    return quantize(product, z_amax), observed_amax
class Fp8Test(absltest.TestCase):
    """Smoke tests for the simulated-FP8 quantize/dequantize/matmul helpers."""

    def test_quantize(self):
        """Round-trip one scalar through quantize() and dequantize()."""
        logging.info('amax = %s', amax(jnp.array([1e-15, 4, 27])))
        logging.info('scale = %s', get_scale(449))

        value = 449
        value_amax = amax(value)
        fp8_v = quantize(value, value_amax)
        logging.info('scaled to fp8 = %s', fp8_v)

        restored = dequantize(fp8_v, value_amax)
        logging.info('unscaled to bf16 = %s', restored)
        # Loose tolerance: the round trip goes through a lower-precision dtype.
        self.assertTrue(bool(jnp.allclose(restored, value, rtol=1e-2)))

    def test_matmul(self):
        """Compare the quantized matmul against a plain jnp.dot."""
        A = jnp.ones((2, 4), dtype=fake_bf16)
        B = jnp.ones((4, 2), dtype=fake_bf16)
        reference = jnp.dot(A, B)

        A_amax = amax(A)
        B_amax = amax(B)
        z_amax = amax(reference)
        A_fp8 = quantize(A, A_amax)
        B_fp8 = quantize(B, B_amax)

        C_fp8, new_z_amax = matmul_f8_rfc_fp8_inp(
            A_fp8, B_fp8, A_amax, B_amax, z_amax)
        # C_fp8 was quantized with z_amax, so the same value must be used to
        # undo the scaling; the returned amax is only the freshly measured one.
        C = dequantize(C_fp8, z_amax)
        logging.info('matmul_f8_rfc_fp8_inp: %s', C)
        logging.info('matmul: %s', reference)
        logging.info('new z amax = %s', new_z_amax)
        self.assertTrue(bool(jnp.allclose(C, reference, rtol=1e-2)))

        def print_ir(f, *args):
            """Log the compiler IR produced by lowering ``f`` under jax.jit."""
            lowered = jax.jit(f).lower(*args)
            logging.info('%s', lowered.compiler_ir())

        print_ir(matmul_f8_rfc_fp8_inp, A_fp8, B_fp8, A_amax, B_amax, z_amax)
# Script entry point: hand control to absltest's test runner when this file is
# executed directly rather than imported.
if __name__ == '__main__':
    absltest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment