Created
November 10, 2023 05:27
-
-
Save zhangqiaorjc/73cc154f22cbae474f6959d9a9fc8589 to your computer and use it in GitHub Desktop.
sincos remat example.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/zhangqiaorjc/73cc154f22cbae474f6959d9a9fc8589/sincos-remat-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "wgXLGxV2-nfI" | |
}, | |
"outputs": [], | |
"source": [ | |
"import jax\n", | |
"from jax import core\n", | |
"from jax.ad_checkpoint import checkpoint_name\n", | |
"import jax.numpy as jnp" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"sincos_p = core.Primitive('sincos')\n", | |
"sincos_p.multiple_results = True" | |
], | |
"metadata": { | |
"id": "OXui8YA2-uF4" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@sincos_p.def_impl\n", | |
"def sincos_impl(x):\n", | |
" return jnp.sin(x), jnp.cos(x)" | |
], | |
"metadata": { | |
"id": "eeeD56dS-8Lc" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@sincos_p.def_abstract_eval\n", | |
"def sincos_abstract_eval(x):\n", | |
" return x, x" | |
], | |
"metadata": { | |
"id": "m0L5ow22_AmN" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def sincos(x):\n", | |
" return sincos_p.bind(x)" | |
], | |
"metadata": { | |
"id": "0xyg3Jg__Q8C" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"jax.make_jaxpr(sincos)(5)" | |
], | |
"metadata": { | |
"id": "m1cFiNCb_LbK", | |
"outputId": "5f464fa0-b521-4f16-98d3-4bea39ec3850" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:i32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\u001b[39m\u001b[22m\u001b[22m b\u001b[35m:i32[]\u001b[39m c\u001b[35m:i32[]\u001b[39m = sincos a \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(b, c) }" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@jax.custom_vjp\n", | |
"def sin(x):\n", | |
" return jnp.sin(x)\n", | |
"\n", | |
"def sin_fwd(x):\n", | |
" return sincos(x)\n", | |
"\n", | |
"def sin_bwd(res, g):\n", | |
" return (res * g,)\n", | |
"\n", | |
"sin.defvjp(sin_fwd, sin_bwd)" | |
], | |
"metadata": { | |
"id": "w6pG4Q0__PdI" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"jax.make_jaxpr(jax.grad(sin))(5.0)" | |
], | |
"metadata": { | |
"id": "MUZNly7x_vWW", | |
"outputId": "34e730b9-6b0f-4794-c2ba-a384f2fc9813" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22m_\u001b[35m:f32[]\u001b[39m b\u001b[35m:f32[]\u001b[39m = sincos a\n", | |
" c\u001b[35m:f32[]\u001b[39m = mul b 1.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(c,) }" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 14 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def loss(x):\n", | |
" return jnp.exp(sin(x))" | |
], | |
"metadata": { | |
"id": "5tgmUiTH_2aH" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"jax.make_jaxpr(jax.grad(loss))(5.0)" | |
], | |
"metadata": { | |
"id": "NW80ucjxAK1z", | |
"outputId": "21839d74-de71-4e48-d91e-222b199b4705" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m c\u001b[35m:f32[]\u001b[39m = sincos a\n", | |
" d\u001b[35m:f32[]\u001b[39m = exp b\n", | |
" e\u001b[35m:f32[]\u001b[39m = mul 1.0 d\n", | |
" f\u001b[35m:f32[]\u001b[39m = mul c e\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(f,) }" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 17 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@jax.checkpoint\n", | |
"def loss(x):\n", | |
" return jnp.exp(sin(x))\n", | |
"\n", | |
"jax.make_jaxpr(jax.grad(loss))(5.0)" | |
], | |
"metadata": { | |
"id": "C1swFN6MAMRa", | |
"outputId": "5a81c4bd-d466-422d-b52a-762441006c91" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m _\u001b[35m:f32[]\u001b[39m = sincos a\n", | |
" _\u001b[35m:f32[]\u001b[39m = exp b\n", | |
" c\u001b[35m:f32[]\u001b[39m = remat2[\n", | |
" differentiated=True\n", | |
" jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; d\u001b[35m:f32[]\u001b[39m e\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mf\u001b[35m:f32[]\u001b[39m g\u001b[35m:f32[]\u001b[39m = sincos d\n", | |
" h\u001b[35m:f32[]\u001b[39m = exp f\n", | |
" i\u001b[35m:f32[]\u001b[39m = mul e h\n", | |
" j\u001b[35m:f32[]\u001b[39m = mul g i\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(j,) }\n", | |
" policy=None\n", | |
" prevent_cse=True\n", | |
" ] a 1.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(c,) }" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 20 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def loss(x):\n", | |
" x = jax._src.ad_checkpoint.checkpoint_name(sin(x), 'sin(x)')\n", | |
" return jnp.exp(x)\n", | |
"\n", | |
"loss = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_only_these_names('sin(x)'))\n", | |
"\n", | |
"jax.make_jaxpr(jax.grad(loss))(5.0)" | |
], | |
"metadata": { | |
"id": "tZP1R5t7AWU8", | |
"outputId": "48d1919c-7d02-4c98-c169-d454239b8ed4" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m _\u001b[35m:f32[]\u001b[39m = sincos a\n", | |
" c\u001b[35m:f32[]\u001b[39m = name[name=sin(x)] b\n", | |
" _\u001b[35m:f32[]\u001b[39m = exp c\n", | |
" d\u001b[35m:f32[]\u001b[39m = remat2[\n", | |
" differentiated=True\n", | |
" jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; e\u001b[35m:f32[]\u001b[39m f\u001b[35m:f32[]\u001b[39m g\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22m_\u001b[35m:f32[]\u001b[39m h\u001b[35m:f32[]\u001b[39m = sincos f\n", | |
" i\u001b[35m:f32[]\u001b[39m = exp e\n", | |
" j\u001b[35m:f32[]\u001b[39m = mul g i\n", | |
" k\u001b[35m:f32[]\u001b[39m = mul h j\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(k,) }\n", | |
" policy=<function save_only_these_names.<locals>.policy at 0x7f531ed7b130>\n", | |
" prevent_cse=True\n", | |
" ] c a 1.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(d,) }" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 23 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"@jax.custom_vjp\n", | |
"def sin(x):\n", | |
" return jnp.sin(x)\n", | |
"\n", | |
"def sin_fwd(x):\n", | |
" sinx, cosx = sincos(x)\n", | |
" sinx = jax._src.ad_checkpoint.checkpoint_name(sinx, 'sin(x)')\n", | |
" cosx = jax._src.ad_checkpoint.checkpoint_name(cosx, 'sin(x)')\n", | |
" return sinx, cosx\n", | |
"\n", | |
"def sin_bwd(res, g):\n", | |
" return (res * g,)\n", | |
"\n", | |
"sin.defvjp(sin_fwd, sin_bwd)\n", | |
"\n", | |
"def loss(x):\n", | |
" return jnp.exp(sin(x))\n", | |
"\n", | |
"loss = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_only_these_names('sin(x)'))\n", | |
"\n", | |
"jax.make_jaxpr(jax.grad(loss))(5.0)" | |
], | |
"metadata": { | |
"id": "QbiqHCecAr_l", | |
"outputId": "31d4a105-df41-4e7b-f17c-814a19ad6a2c" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; a\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mb\u001b[35m:f32[]\u001b[39m c\u001b[35m:f32[]\u001b[39m = sincos a\n", | |
" d\u001b[35m:f32[]\u001b[39m = name[name=sin(x)] b\n", | |
" e\u001b[35m:f32[]\u001b[39m = name[name=sin(x)] c\n", | |
" _\u001b[35m:f32[]\u001b[39m = exp d\n", | |
" f\u001b[35m:f32[]\u001b[39m = remat2[\n", | |
" differentiated=True\n", | |
" jaxpr={ \u001b[34m\u001b[22m\u001b[1mlambda \u001b[39m\u001b[22m\u001b[22m; g\u001b[35m:f32[]\u001b[39m h\u001b[35m:f32[]\u001b[39m i\u001b[35m:f32[]\u001b[39m. \u001b[34m\u001b[22m\u001b[1mlet\n", | |
" \u001b[39m\u001b[22m\u001b[22mj\u001b[35m:f32[]\u001b[39m = exp h\n", | |
" k\u001b[35m:f32[]\u001b[39m = mul i j\n", | |
" l\u001b[35m:f32[]\u001b[39m = mul g k\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(l,) }\n", | |
" policy=<function save_only_these_names.<locals>.policy at 0x7f531ed7ba30>\n", | |
" prevent_cse=True\n", | |
" ] e d 1.0\n", | |
" \u001b[34m\u001b[22m\u001b[1min \u001b[39m\u001b[22m\u001b[22m(f,) }" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 25 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "t0IEhLtICZY4" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment