@zhangqiaorjc
Created April 22, 2023 17:04
bfloat16-training.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/zhangqiaorjc/6a254da46570335695f0c1095b96d837/bfloat16-training.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "2ckXqCO6zooS"
},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax import linen as nn\n",
"import optax"
]
},
{
"cell_type": "code",
"source": [
"class Model(nn.Module):\n",
" @nn.compact\n",
" def __call__(self, x):\n",
" # w initialized to f32\n",
" w = self.param('w', nn.initializers.lecun_normal(),\n",
" (x.shape[-1], 2), jnp.float32)\n",
" # x and w are bf16 in step func\n",
" return x @ w\n",
"\n",
"m = Model()\n",
"x_f32 = jnp.ones((1, 2), dtype=jnp.float32)\n",
"w_f32 = m.init(jax.random.PRNGKey(0), x_f32)\n",
"\n",
"optimzer = optax.adamw(0.001)\n",
"opt_state_f32 = optimzer.init(w_f32)\n",
"\n",
"def loss(w_f32, x_f32):\n",
" w_bf16, x_bf16 = jax.tree_map(lambda x: x.astype(jnp.bfloat16),\n",
" (w_f32, x_f32))\n",
" loss_bf16 = m.apply(w_bf16, x_bf16).sum()\n",
" return loss_bf16\n",
"\n",
"grad_f = jax.grad(loss)\n",
"\n",
"def step(w_f32, x_f32, opt_state_f32):\n",
" grads_f32 = grad_f(w_f32, x_f32)\n",
" updates_f32, new_opt_state_f32 = optimzer.update(grads_f32, opt_state_f32, w_f32)\n",
" new_w_f32 = optax.apply_updates(w_f32, updates_f32)\n",
" return new_w_f32, new_opt_state_f32\n",
"\n",
"new_w_f32, new_opt_state_f32 = step(w_f32, x_f32, opt_state_f32)\n",
"print('new_w_f32=', new_w_f32)\n",
"print('new_opt_state_f32=', new_opt_state_f32)\n"
],
"metadata": {
"id": "t3pJCg2kzwDU",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "5e0017c2-f4c9-4ab2-b3b8-6a5c5f288c8d"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"new_w_f32= FrozenDict({\n",
" params: {\n",
" w: Array([[-1.10556 , -1.1878285],\n",
" [-0.9265006, 0.1304449]], dtype=float32),\n",
" },\n",
"})\n",
"new_opt_state_f32= (ScaleByAdamState(count=Array(1, dtype=int32), mu=FrozenDict({\n",
" params: {\n",
" w: Array([[0.1, 0.1],\n",
" [0.1, 0.1]], dtype=float32),\n",
" },\n",
"}), nu=FrozenDict({\n",
" params: {\n",
" w: Array([[0.001, 0.001],\n",
" [0.001, 0.001]], dtype=float32),\n",
" },\n",
"})), EmptyState(), EmptyState())\n"
]
}
]
},
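{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a minimal usage sketch, not part of the original gist: it jits the `step` function defined above and runs a few iterations, then prints the parameter dtypes. The names `jit_step`, `w`, and `opt_state` are illustrative. It shows the same point as above: master weights and optimizer state stay in float32 while the matmul inside `loss` runs in bfloat16."
]
},
{
"cell_type": "code",
"source": [
"# Minimal sketch (not in the original gist): jit the `step` above and run a few\n",
"# iterations; the names jit_step, w, opt_state are illustrative.\n",
"jit_step = jax.jit(step)\n",
"\n",
"w, opt_state = w_f32, opt_state_f32\n",
"for _ in range(3):\n",
"  w, opt_state = jit_step(w, x_f32, opt_state)\n",
"\n",
"# Master weights remain float32 even though the forward matmul ran in bfloat16.\n",
"print(jax.tree_util.tree_map(lambda a: a.dtype, w))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},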
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "_Qz32arW46v3"
},
"execution_count": null,
"outputs": []
}
]
}