Created
May 2, 2015 03:22
-
-
Save anonymous/7d9a56b18a514adf125c to your computer and use it in GitHub Desktop.
Deep RNN: using scan or unrolled expressions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import theano | |
import theano.tensor as T | |
import numpy as np | |
import time | |
# Benchmark configuration.
use_scan = True  # set to False to benchmark the unrolled expression instead
m = 2**10  # batch size
n = 2**12  # number of hidden units per layer
depth = 8  # number of stacked recurrent layers
t = 8  # time steps

# The benchmark is only meaningful under this Theano configuration, so
# refuse to run otherwise and report the value that was actually found.
assert theano.config.floatX == 'float32', (
    "expected theano.config.floatX == 'float32', got %r"
    % (theano.config.floatX,))
assert theano.config.optimizer == 'fast_run', (
    "expected theano.config.optimizer == 'fast_run', got %r"
    % (theano.config.optimizer,))

# Fixed seed so every run draws the same random inputs and weights.
np.random.seed(0)
def relu(x):
    """Rectified linear unit: keep positive entries of ``x``, zero the rest.

    Written as ``x * (x > 0)`` so that it only relies on elementwise
    comparison and multiplication; the same expression therefore serves
    plain numbers, NumPy arrays and the Theano expressions it is applied
    to in ``rnn_step``.
    """
    return x * (x > 0)
def rand(*size):
    """Return a random NumPy array of shape ``size`` in Theano's float type.

    Values are drawn from a zero-mean normal distribution with standard
    deviation 1e-3 using NumPy's global random state, then converted to
    ``theano.config.floatX``.
    """
    samples = np.random.normal(size=size, scale=1e-3)
    return np.array(samples, dtype=theano.config.floatX)
def init(*size):
    """Return a Theano shared variable initialised with ``rand(*size)``."""
    return theano.shared(rand(*size))
print("symbolic input...")
# Symbolic inputs of the compiled function; concrete values with matching
# ranks are created further down (x_val, targets_val, w_val, h0_val).
x = T.tensor3()        # input sequence, indexed by time step first
targets = T.tensor3()  # regression targets, same layout as x
w = T.matrix()         # single weight matrix used by every layer in rnn_step
# One initial hidden state per layer.
h0 = [T.matrix() for _ in range(depth)]
def rnn_step(*args):
    """Advance every layer of the stacked RNN by one time step.

    Args:
        *args: ``args[0]`` is the input for the current time step and
            ``args[1:]`` are the previous hidden states, one per layer,
            ordered from the bottom layer to the top.

    Returns:
        A list holding the new hidden state of each layer, bottom to top.

    Every layer applies the module-level weight matrix ``w`` to both its
    bottom-up input and its own previous state.
    """
    x_curr = args[0]
    h_prev = args[1:]
    h_curr = []
    for h in h_prev:
        # The bottom layer reads the input; each higher layer reads the
        # freshly computed state of the layer directly below it.
        h_below = h_curr[-1] if h_curr else x_curr
        h_curr.append(relu(T.dot(h_below, w) + T.dot(h, w)))
    return h_curr
# Build the scalar training error, either with theano.scan or by unrolling
# the recurrence over the t time steps in Python.
if use_scan:
    # scan calls rnn_step with one time slice of x followed by the previous
    # hidden states, and returns one stacked output sequence per layer.
    out, updates = theano.scan(rnn_step,
                               sequences=x,
                               outputs_info=h0,
                               non_sequences=[])
    # With a single output (depth == 1) scan returns the bare variable
    # instead of a list; wrap it so out[-1] always means "top layer" and
    # never "last time step".
    if not isinstance(out, (list, tuple)):
        out = [out]
    err = ((out[-1] - targets) ** 2).mean()
else:
    err = 0
    state = h0
    for i in range(t):
        state = rnn_step(x[i], *state)
        # Average the per-step mean squared error of the top layer over time.
        err += ((state[-1] - targets[i]) ** 2).mean() / t
    # The unrolled graph has no scan updates; start from an empty set so the
    # gradient update can be appended the same way in both modes.
    updates = theano.OrderedUpdates()
print("allocating...")
# Shared buffer that receives the gradient of the error w.r.t. w.
g_out = init(n, n)
# Concrete inputs. The order of the rand() calls is kept as-is because they
# all consume the same seeded global random state.
x_val = rand(t, m, n)
targets_val = rand(t, m, n)
w_val = rand(n, n)
h0_val = [rand(m, n) for _ in range(depth)]
print("compiling...")
# The gradient is written into the shared variable g_out through an update
# rather than returned as a function output.
f = theano.function([x, w, targets] + h0, err,
                    updates=updates + [(g_out, T.grad(err, w))])

print("running...")
# Time only the call itself; fetching and printing the n-by-n gradient
# afterwards must not count towards the measured throughput.
t0 = time.time()
f_out = f(x_val, w_val, targets_val, *h0_val)
elapsed = time.time() - t0
print("%s %s" % (f_out, g_out.get_value()))

# 4 bytes per float32 element.
# NOTE(review): the terms presumably count activations, hidden states and
# weight/gradient buffers -- confirm against the intended memory model.
GB = 4. * (m * n * (depth+2) + m * n * depth + 2 * n * n * depth) / 1024**3
# NOTE(review): the 3 * 2 * 2 factor presumably stands for forward plus
# backward passes, two matrix products per layer and two flops per
# multiply-add -- confirm.
TFLOPS = 3 * 2 * 2 * m * n * n * t * depth / elapsed / 1e12
print("expected memory usage = %s GB, measured TFLOPS = %s" % (GB, TFLOPS))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment