Created
June 13, 2019 09:39
-
-
Save victorkohler/0c161f8dfb790d4d21f3d003d8cdebb4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#-------------
# HYPERPARAMS
#-------------
num_neg = 4            # negative samples drawn per positive interaction
latent_features = 8    # dimensionality of the user/item embedding vectors
epochs = 20            # full passes over the training data
batch_size = 256       # examples per gradient step
learning_rate = 0.001  # Adam optimizer step size
#-------------------------
# TENSORFLOW GRAPH
#-------------------------
# Builds a generalized-matrix-factorization model (TF 1.x graph mode):
# user/item id -> embedding -> element-wise product -> single sigmoid unit.
# NOTE(review): assumes `users`, `items`, and `latent_features` are defined
# earlier in this file — confirm against the full script.
graph = tf.Graph()
with graph.as_default():
    # Define input placeholders for user, item and label.
    # Ids arrive as (batch, 1) int32 columns; labels are 0/1 interaction flags
    # (sigmoid_cross_entropy casts them to the logits' float dtype internally).
    user = tf.placeholder(tf.int32, shape=(None, 1))
    item = tf.placeholder(tf.int32, shape=(None, 1))
    label = tf.placeholder(tf.int32, shape=(None, 1))

    # User feature embedding: one latent vector per user id.
    u_var = tf.Variable(tf.random_normal([len(users), latent_features],
                                         stddev=0.05), name='user_embedding')
    user_embedding = tf.nn.embedding_lookup(u_var, user)

    # Item feature embedding: one latent vector per item id.
    i_var = tf.Variable(tf.random_normal([len(items), latent_features],
                                         stddev=0.05), name='item_embedding')
    item_embedding = tf.nn.embedding_lookup(i_var, item)

    # Flatten (batch, 1, latent) lookups down to (batch, latent).
    user_embedding = tf.keras.layers.Flatten()(user_embedding)
    item_embedding = tf.keras.layers.Flatten()(item_embedding)

    # Element-wise product of the user and item latent vectors (GMF core).
    prediction_matrix = tf.multiply(user_embedding, item_embedding)

    # Single-neuron output layer producing the interaction logit.
    output_layer = tf.keras.layers.Dense(1,
                                         kernel_initializer="lecun_uniform",
                                         name='output_layer')(prediction_matrix)

    # Binary cross-entropy on the logit vs. the 0/1 label.
    loss = tf.losses.sigmoid_cross_entropy(label, output_layer)

    # Train using the Adam optimizer to minimize our loss.
    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    step = opt.minimize(loss)

    # Initializer for all variables created above, bound to this graph.
    init = tf.global_variables_initializer()

# Session is tied to the graph explicitly, so it can live outside the block.
session = tf.Session(config=None, graph=graph)
session.run(init)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment