This gist is a snippet of the standard LSTM code. It is hard to find up-to-date LSTM snippets in the community, so I provide it for reference when building a complete project.
# standard_lstm.py

import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.contrib import legacy_seq2seq
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.math_ops import sigmoid
from tensorflow.python.ops.math_ops import tanh


def standard_lstm(inputs, init_h, init_c, input_kernel, hidden_kernel, bias,
                  tanh_activation, sigm_activation, mask, time_major,
                  go_backwards):
    """LSTM with standard kernel implementation.

    Args:
      inputs: input tensor of the LSTM layer.
      init_h: initial state tensor for the cell output.
      init_c: initial state tensor for the cell hidden state.
      input_kernel: weights for the input kernel applied to the input x_t.
      hidden_kernel: weights for the hidden kernel applied to the hidden state.
      bias: weights for the input kernel bias and hidden bias. Only the hidden
        bias is used in this case.
      tanh_activation: tanh activation function to use for both the j gate and
        the output.
      sigm_activation: sigmoid activation function to use for the hidden
        recurrent state.
      mask: boolean tensor used to mask out steps within the sequence.
      time_major: boolean, whether the inputs are in the format of
        [time, batch, feature] or [batch, time, feature].
      go_backwards: boolean (default False). If True, process the input
        sequence backwards and return the reversed sequence.

    Returns:
      last_output: output tensor for the last timestep, which has shape
        [batch, units].
      outputs: output tensor for all timesteps, which has shape
        [batch, time, units].
      state_0: the cell output, which has the same shape as init_h.
      state_1: the cell hidden state, which has the same shape as init_c.
      runtime: constant string tensor which indicates the real runtime
        hardware. This value is for testing purposes.
    """
    input_shape = K.int_shape(inputs)
    timesteps = input_shape[0] if time_major else input_shape[1]

    def step(cell_inputs, cell_states):
        # Step function that will be used by the Keras RNN backend.
        h_t1 = cell_states[0]  # previous hidden state
        c_t1 = cell_states[1]  # previous cell state

        # cell_inputs is the input x_t of the current timestep for the LSTMCell.
        z = K.dot(cell_inputs, input_kernel)
        z += K.dot(h_t1, hidden_kernel)
        z = K.bias_add(z, bias)

        # The packed pre-activations are split into the four gate vectors.
        z0, z1, z2, z3 = array_ops.split(z, 4, axis=1)

        i = sigm_activation(z0)
        f = sigm_activation(z1)
        j = tanh_activation(z2)
        o = sigm_activation(z3)

        c = f * c_t1 + i * j
        h = o * tanh_activation(c)

        # Please notice h denotes the sequence output and [h, c] denotes the
        # state of both h and c carried to the next timestep.
        return h, [h, c]

    last_output, outputs, new_states = K.rnn(
        step,
        inputs, [init_h, init_c],
        constants=None,
        unroll=False,
        time_major=time_major,
        mask=mask,
        go_backwards=go_backwards,
        input_length=timesteps)

    # _runtime is a small helper from the TensorFlow Keras recurrent
    # implementation (not included in this snippet) that reports the runtime
    # hardware as a constant tensor.
    return last_output, outputs, new_states[0], new_states[1], _runtime('cpu')
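To make the calling convention concrete, here is a minimal usage sketch. The shapes and names (batch, time, features, units) are hypothetical, the kernels simply pack the four gates side by side (hence the factor of 4), and it assumes the `_runtime` helper from the TensorFlow Keras recurrent implementation is available or stubbed out.

# usage_sketch.py (hypothetical, not part of the original gist)

import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.ops.math_ops import sigmoid, tanh

# Hypothetical dimensions for the sketch.
batch, time, features, units = 4, 10, 8, 16

inputs = tf.constant(np.random.rand(batch, time, features), dtype=tf.float32)
init_h = tf.zeros([batch, units])
init_c = tf.zeros([batch, units])

# The kernels pack the four gates (i, f, j, o) side by side: factor of 4.
input_kernel = tf.constant(np.random.rand(features, 4 * units), dtype=tf.float32)
hidden_kernel = tf.constant(np.random.rand(units, 4 * units), dtype=tf.float32)
bias = tf.zeros([4 * units])

last_output, outputs, state_h, state_c, runtime = standard_lstm(
    inputs, init_h, init_c, input_kernel, hidden_kernel, bias,
    tanh_activation=tanh, sigm_activation=sigmoid,
    mask=None, time_major=False, go_backwards=False)

print(K.int_shape(last_output))  # (4, 16)
print(K.int_shape(outputs))      # (4, 10, 16)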
# standard_lstm.py

"""
An LSTM is structured around three major components: the hidden state, the gates
and the cell state. The gates are the forget gate (f), the input gate (i), the
joint/candidate gate (j) and the output gate (o). The previous hidden state h_t1
is multiplied by the hidden_kernel along the hidden state chain, the previous
cell state c_t1 enters the LSTMCell along the cell state chain, and the input x_t
is multiplied by the input_kernel. The LSTMCell then adds the two products
together with a shared bias of the same dimensions. Afterwards, three sigmoid (σ)
functions and one tanh function are applied to the four resulting vectors z0, z1,
z2 and z3 to obtain i, f, j and o respectively. Matching the code above, the four
formulas are as follows.

    z0 = σ(Wi*h_t1 + Ui*x_t + bi)      -> input gate i
    z1 = σ(Wf*h_t1 + Uf*x_t + bf)      -> forget gate f
    z2 = tanh(Wj*h_t1 + Uj*x_t + bj)   -> joint/candidate gate j
    z3 = σ(Wo*h_t1 + Uo*x_t + bo)      -> output gate o

where Wi, Wf, Wj and Wo denote the hidden_kernel and Ui, Uf, Uj and Uo denote the
input_kernel. Both W and U are matrices; bi, bf, bj and bo are the respective
biases. This notation is not written into the code but is used here for
understanding.

TensorFlow has not explained the update behaviour of the two kinds of kernels, so
the kernels and their related multiplications look like black boxes. I therefore
name the hidden_kernel for the previous hidden state and the input_kernel for the
input x_t for easier understanding. Please notice that the bias is shared.

Next, the pairs f and c_t1, and i and j, are multiplied elementwise, and the two
products are added together to get the new cell state c (denoted c_t in most
papers):

    c = f * c_t1 + i * j

Moreover, the tanh activation is applied to c, and o is multiplied by tanh(c) to
get the new hidden state h (usually denoted h_t in the papers):

    h = o * tanh(c)

The step function returns h as the hidden state sequence output and the pair
[h, c] as the output state of the current LSTMCell. Therefore, we can see both
the hidden state sequence and the [h, c] state. However, the form of [h, c] looks
like a sequence, so developers can easily misunderstand its usage.

Please notice that the forget gate is actually a remembering gate: it remembers
at 1 and forgets at 0.

colah: https://colah.github.io/posts/2015-08-Understanding-LSTMs/
RECURRENT NEURAL NETWORK REGULARIZATION: https://arxiv.org/abs/1409.2329
Toronto: http://www.cs.toronto.edu/~rgrosse/courses/csc321_2017/readings/L15%20Exploding%20and%20Vanishing%20Gradients.pdf
"""