Created
April 22, 2023 17:04
-
-
Save zhangqiaorjc/6a254da46570335695f0c1095b96d837 to your computer and use it in GitHub Desktop.
bfloat16-training.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/zhangqiaorjc/6a254da46570335695f0c1095b96d837/bfloat16-training.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"id": "2ckXqCO6zooS" | |
}, | |
"outputs": [], | |
"source": [ | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"from flax import linen as nn\n", | |
"import optax" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class Model(nn.Module):\n", | |
" @nn.compact\n", | |
" def __call__(self, x):\n", | |
" # w initialized to f32\n", | |
" w = self.param('w', nn.initializers.lecun_normal(),\n", | |
" (x.shape[-1], 2), jnp.float32)\n", | |
" # x and w are bf16 in step func\n", | |
" return x @ w\n", | |
"\n", | |
"m = Model()\n", | |
"x_f32 = jnp.ones((1, 2), dtype=jnp.float32)\n", | |
"w_f32 = m.init(jax.random.PRNGKey(0), x_f32)\n", | |
"\n", | |
"optimzer = optax.adamw(0.001)\n", | |
"opt_state_f32 = optimzer.init(w_f32)\n", | |
"\n", | |
"def loss(w_f32, x_f32):\n", | |
" w_bf16, x_bf16 = jax.tree_map(lambda x: x.astype(jnp.bfloat16),\n", | |
" (w_f32, x_f32))\n", | |
" loss_bf16 = m.apply(w_bf16, x_bf16).sum()\n", | |
" return loss_bf16\n", | |
"\n", | |
"grad_f = jax.grad(loss)\n", | |
"\n", | |
"def step(w_f32, x_f32, opt_state_f32):\n", | |
" grads_f32 = grad_f(w_f32, x_f32)\n", | |
" updates_f32, new_opt_state_f32 = optimzer.update(grads_f32, opt_state_f32, w_f32)\n", | |
" new_w_f32 = optax.apply_updates(w_f32, updates_f32)\n", | |
" return new_w_f32, new_opt_state_f32\n", | |
"\n", | |
"new_w_f32, new_opt_state_f32 = step(w_f32, x_f32, opt_state_f32)\n", | |
"print('new_w_f32=', new_w_f32)\n", | |
"print('new_opt_state_f32=', new_opt_state_f32)\n" | |
], | |
"metadata": { | |
"id": "t3pJCg2kzwDU", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "5e0017c2-f4c9-4ab2-b3b8-6a5c5f288c8d" | |
}, | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"new_w_f32= FrozenDict({\n", | |
" params: {\n", | |
" w: Array([[-1.10556 , -1.1878285],\n", | |
" [-0.9265006, 0.1304449]], dtype=float32),\n", | |
" },\n", | |
"})\n", | |
"new_opt_state_f32= (ScaleByAdamState(count=Array(1, dtype=int32), mu=FrozenDict({\n", | |
" params: {\n", | |
" w: Array([[0.1, 0.1],\n", | |
" [0.1, 0.1]], dtype=float32),\n", | |
" },\n", | |
"}), nu=FrozenDict({\n", | |
" params: {\n", | |
" w: Array([[0.001, 0.001],\n", | |
" [0.001, 0.001]], dtype=float32),\n", | |
" },\n", | |
"})), EmptyState(), EmptyState())\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "_Qz32arW46v3" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment