Last active
October 28, 2020 14:15
-
-
Save epignatelli/42662667f7be1ea0daa3e34a375f027a to your computer and use it in GitHub Desktop.
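A `module` decorator for jax.experimental.stax layer constructors: it makes them return a named `Module(init, apply)` pair instead of a plain tuple.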
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
from typing import Any, Callable, NamedTuple, Tuple

import jax.numpy as jnp
# Type aliases used by the annotations in this file.
Params = Any  # layer parameters; whatever structure a layer's init function returns
RNGKey = jnp.ndarray  # a JAX PRNG key array
# Variable-length tuple of ints: `Tuple[int]` would mean "exactly one int",
# which does not describe shapes of other ranks such as (-1, 4).
Shape = Tuple[int, ...]
class Module(NamedTuple):
    """A layer packaged as a named ``(init, apply)`` pair.

    Because this is a ``NamedTuple``, an instance still behaves like the plain
    2-tuple it replaces: it can be unpacked (``init, apply = m``) and indexed
    (``m[0]``, ``m[1]``), in addition to attribute access by name.
    """

    # init(rng, input_shape) -> (output_shape, params)
    init: Callable[[RNGKey, Shape], Tuple[Shape, Params]]
    # apply(params, inputs) -> outputs
    apply: Callable[[Params, jnp.ndarray], jnp.ndarray]
def module(module_maker):
    """Decorate a layer constructor so that it returns a ``Module``.

    Args:
        module_maker: A callable that returns an ``(init, apply)`` pair of
            functions when called with the layer's hyper-parameters.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``module_maker`` and returns ``Module(init, apply)``. The wrapper
        keeps the name and docstring of ``module_maker``.
    """

    @functools.wraps(module_maker)
    def fabricate_module(*args, **kwargs):
        # Unpacking into exactly two names makes a maker that does not
        # return a pair fail here rather than later at use.
        init, apply = module_maker(*args, **kwargs)
        return Module(init, apply)

    return fabricate_module
# Example usage: a stax-style Dense layer constructor wrapped with @module.
if __name__ == "__main__":
    import jax
    # Take the initializers from jax.nn.initializers rather than from stax:
    # only the initializers are needed here (Dense is defined below), and the
    # stax import path is not stable across JAX releases.
    from jax.nn.initializers import glorot_normal, normal

    @module
    def Dense(out_dim, W_init=glorot_normal(), b_init=normal()):
        """Layer constructor function for a dense (fully-connected) layer."""

        def init_fun(rng, input_shape):
            # Replace the last (feature) dimension with the layer's width.
            output_shape = input_shape[:-1] + (out_dim,)
            k1, k2 = jax.random.split(rng)
            W = W_init(k1, (input_shape[-1], out_dim))
            b = b_init(k2, (out_dim,))
            return output_shape, (W, b)

        def apply_fun(params, inputs, **kwargs):
            W, b = params
            return jnp.dot(inputs, W) + b

        return init_fun, apply_fun

    seed = 0
    rng = jax.random.PRNGKey(seed)
    # Use independent keys for parameter initialisation and for sampling the
    # input; a JAX PRNG key should not be consumed twice.
    init_rng, data_rng = jax.random.split(rng)
    input_shape = (-1, 4)

    dense = Dense(8, W_init=glorot_normal(), b_init=normal())
    print(dense)
    # The Module is accessible both by field name and by tuple index.
    assert dense.init == dense[0]
    assert dense.apply == dense[1]

    out_shape, params = dense.init(init_rng, input_shape)
    noise = jax.random.normal(data_rng, (input_shape[-1],))
    output = dense.apply(params, noise)
    print(output.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment