Created
November 30, 2022 19:38
-
-
Save zhangqiaorjc/25c9d753864cd95942ffd9ba48b63f42 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Runs a simple mnist model with fake FP8. FP8 scaling is used. | |
The HLO can be dumped by setting the environment variable: | |
XLA_FLAGS='--xla_dump_disable_metadata=true --xla_dump_to=/tmp/hlo' | |
""" | |
import tensorflow as tf | |
# Master switch: when True, every DenseWithScaling layer in the model
# fake-quantizes its kernel, its output, and its output gradient.
USE_QUANT = True

tf.keras.utils.set_random_seed(1)

# Real FP8 dtypes are not available, so 16-bit types stand in for them:
# float16 plays the role of E4M3 and bfloat16 the role of E5M2.
FAKE_E4M3 = tf.float16
FAKE_E5M2 = tf.bfloat16

# Saturation limits (largest finite magnitudes) of the two FP8 formats.
E4M3_MAX = 448.0
E5M2_MAX = 57344.0
def get_fp8_max(fake_dtype):
    """Return the saturation limit of the FP8 format a fake dtype stands for.

    Args:
        fake_dtype: Either FAKE_E4M3 or FAKE_E5M2.

    Returns:
        E4M3_MAX for FAKE_E4M3, E5M2_MAX for FAKE_E5M2.

    Raises:
        ValueError: If `fake_dtype` is neither of the fake FP8 dtypes.
    """
    if fake_dtype == FAKE_E4M3:
        return E4M3_MAX
    if fake_dtype == FAKE_E5M2:
        return E5M2_MAX
    # An explicit raise (instead of `assert`) stays active under `python -O`,
    # where asserts are stripped and any unknown dtype would otherwise be
    # silently given the E5M2 range.
    raise ValueError(f"Unsupported fake FP8 dtype: {fake_dtype!r}")
def quantize(x, quantized_dtype, scale):
    """Divide `x` by `scale`, saturate to the FP8 range, and cast down.

    Args:
        x: Tensor to quantize.
        quantized_dtype: FAKE_E4M3 or FAKE_E5M2; selects both the saturation
            limit and the output dtype.
        scale: Scalar (tensor or variable) that `x` is divided by.

    Returns:
        `x / scale`, clipped to [-limit, limit] and cast to `quantized_dtype`.
    """
    limit = get_fp8_max(quantized_dtype)
    scaled = x / scale
    saturated = tf.clip_by_value(scaled, -limit, limit)
    return tf.cast(saturated, quantized_dtype)
def dequantize(x, wide_dtype, scale):
    """Undo `quantize`: cast `x` up to `wide_dtype` and multiply by `scale`."""
    widened = tf.cast(x, wide_dtype)
    return widened * scale
def quantize_dequantize(x, quantized_dtype, scale):
    """Round-trip `x` through `quantized_dtype` and back to its own dtype.

    The result has the dtype of `x` but only the precision and (scaled) range
    that `quantized_dtype` can represent.
    """
    wide_dtype = x.dtype
    narrowed = quantize(x, quantized_dtype, scale)
    return dequantize(narrowed, wide_dtype, scale)
def update_scale(x, quantized_dtype, scale_var):
    """Recompute `scale_var` from the absolute maximum of `x`.

    The new scale is 1.1 * max(|x|) / limit, so that the largest magnitude in
    `x` maps to roughly 91% of the FP8 saturation limit. The result is written
    into `scale_var` in place; nothing is returned.
    """
    limit = get_fp8_max(quantized_dtype)
    peak = tf.math.reduce_max(tf.math.abs(x))
    peak = tf.cast(peak, scale_var.dtype)
    # Floor the peak so an all-zero tensor can never produce a zero scale,
    # which would turn the next `x / scale` into a division by zero.
    peak = tf.maximum(peak, 2 ** -10)
    with_headroom = 1.1 * peak
    scale_var.assign(with_headroom / limit)
def qdq_and_update(x, dtype, scale_var):
    """Quantize-dequantize `x` with the current scale, then refresh the scale.

    `x` is rounded using the value `scale_var` holds *before* this call. The
    scale is then recomputed from `x` itself, so the new value only affects
    the next use of `scale_var`.

    Returns:
        The quantize-dequantized copy of `x`.
    """
    rounded = quantize_dequantize(x, dtype, scale_var)
    update_scale(x, dtype, scale_var)
    return rounded
class DenseWithScaling(tf.keras.layers.Layer):
    """Dense layer with optional fake-FP8 quantization of kernel and output.

    When `use_quant` is True:
      * the kernel is quantize-dequantized to FAKE_E4M3 in the forward pass,
        and its gradient passes through unchanged;
      * the (activated) output is quantize-dequantized to FAKE_E4M3 in the
        forward pass, and the gradient arriving at it is quantize-dequantized
        to FAKE_E5M2 in the backward pass.
    Each of those three tensors has its own non-trainable scale variable. The
    variable is refreshed from the tensor's abs-max right after it is used.

    Args:
        units: Number of output features.
        activation: Anything accepted by `tf.keras.activations.get`; None
            means linear.
        use_quant: Whether to enable fake FP8 quantization.
        **kwargs: Forwarded to `tf.keras.layers.Layer` (e.g. `name`, `dtype`).
    """

    def __init__(self, units, activation=None, use_quant=False, **kwargs):
        super().__init__(**kwargs)
        self.units = int(units)
        self.activation = tf.keras.activations.get(activation)
        self.use_quant = use_quant

    @property
    def actvation(self):
        """Deprecated misspelled alias of `activation`, kept for compatibility."""
        return self.activation

    def build(self, input_shape):
        last_dim = input_shape[-1]
        # Weight names are passed by keyword and are unique within the layer,
        # so saving formats that key on names do not collide.
        self.kernel = self.add_weight(
            name="kernel",
            shape=[last_dim, self.units],
            initializer="glorot_uniform",
        )
        self.bias = self.add_weight(
            name="bias",
            shape=[self.units],
            initializer="zeros",
        )
        if self.use_quant:
            # 32 is only the starting scale; qdq_and_update overwrites each
            # scale with an abs-max-derived value after its first use.
            init32 = tf.keras.initializers.Constant(32.)
            self.output_scale = self.add_weight(
                name="output_scale", shape=(), initializer=init32,
                trainable=False)
            self.kernel_scale = self.add_weight(
                name="kernel_scale", shape=(), initializer=init32,
                trainable=False)
            self.output_grad_scale = self.add_weight(
                name="output_grad_scale", shape=(), initializer=init32,
                trainable=False)
        super().build(input_shape)

    @tf.custom_gradient
    def out_qdq(self, out):
        """Quantize-dequantize both the output and the output's gradient.

        Forward: `out` goes to FAKE_E4M3 using `output_scale`. Backward: the
        incoming gradient goes to FAKE_E5M2 using `output_grad_scale`. Both
        scales are refreshed as a side effect.
        """
        qout = qdq_and_update(out, FAKE_E4M3, self.output_scale)

        def grad(out_grad):
            return qdq_and_update(out_grad, FAKE_E5M2, self.output_grad_scale)

        return qout, grad

    @tf.custom_gradient
    def kernel_qdq(self, kernel):
        """Quantize-dequantize the kernel but not its gradient.

        The gradient is passed through unchanged (straight-through), so the
        optimizer sees the gradient w.r.t. the unquantized kernel.
        """
        qkernel = qdq_and_update(kernel, FAKE_E4M3, self.kernel_scale)

        def grad(kernel_grad):
            return kernel_grad

        return qkernel, grad

    def call(self, inputs):
        # Read the kernel here so that kernel_qdq receives it as an explicit
        # tensor input, rather than reading the variable inside the
        # custom-gradient function.
        kernel = self.kernel.read_value()
        if self.use_quant:
            kernel = self.kernel_qdq(kernel)
        out = inputs @ kernel + self.bias
        out = self.activation(out)
        if self.use_quant:
            # The output is quantized after the activation.
            out = self.out_qdq(out)
        return out
class MnistModel(tf.keras.Model):
    """Three-layer MLP (64 -> 64 -> 10) that returns unnormalized logits.

    Sublayers are created in `__init__`, as the Keras subclassing guide
    recommends. They therefore exist (and are tracked) before the first call,
    and an explicit `model.build(input_shape)` uses Keras' default build.

    Args:
        *args: Forwarded to `tf.keras.Model`.
        use_quant: Whether the dense layers use fake FP8 quantization.
            None (the default) means "use the module-level USE_QUANT flag".
        **kwargs: Forwarded to `tf.keras.Model` (e.g. `name`).
    """

    def __init__(self, *args, use_quant=None, **kwargs):
        super().__init__(*args, **kwargs)
        if use_quant is None:
            use_quant = USE_QUANT
        self.dense1 = DenseWithScaling(64, activation="relu", use_quant=use_quant)
        self.dense2 = DenseWithScaling(64, activation="relu", use_quant=use_quant)
        # No activation: the loss is configured with from_logits=True.
        self.dense3 = DenseWithScaling(10, use_quant=use_quant)

    def call(self, inputs):
        x = self.dense1(inputs)
        x = self.dense2(x)
        output = self.dense3(x)
        return output
model = MnistModel()

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Flatten each 28x28 image to a 784-vector and rescale pixels to [0, 1].
# The leading dimension comes from the loaded arrays instead of hard-coded
# 60000/10000, so this also works on a subset of the data.
x_train = x_train.reshape(len(x_train), 784).astype("float32") / 255
x_test = x_test.reshape(len(x_test), 784).astype("float32") / 255

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.RMSprop(),
    metrics=["accuracy"],
    # run_eagerly=True,
    jit_compile=True,
)

history = model.fit(x_train, y_train, batch_size=64, epochs=2,
                    validation_split=0.2, verbose=1)

test_scores = model.evaluate(x_test, y_test, verbose=2)
print("Test loss:", test_scores[0])
print("Test accuracy:", test_scores[1])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment