RBM model architecture
import torch


class RBM:
    def __init__(self, n_vis, n_hid):
        """
        Initialize the parameters (weights and biases) optimized during training
        :param n_vis: number of visible units
        :param n_hid: number of hidden units
        """
        # Weights used for the probability of the visible units given the hidden units
        self.W = torch.randn(n_hid, n_vis)  # torch.randn: standard normal distribution, mean = 0, variance = 1
        # Bias for the probability that a visible unit is activated, given the hidden units (p_v_given_h)
        self.v_bias = torch.randn(1, n_vis)  # extra dimension of size 1 for broadcasting over the batch
        # Bias for the probability that a hidden unit is activated, given the visible units (p_h_given_v)
        self.h_bias = torch.randn(1, n_hid)  # extra dimension of size 1 for broadcasting over the batch
    def sample_h(self, x):
        """
        Sample the hidden units
        :param x: mini-batch of visible units (one row of movie ratings per user)
        """
        # The probability that h is activated, given v, is sigmoid(Wx + a)
        # torch.mm computes the matrix product of two tensors
        # W.t() takes the transpose because W is stored as (n_hid, n_vis)
        wx = torch.mm(x, self.W.t())
        # Broadcast the hidden bias over the mini-batch
        activation = wx + self.h_bias.expand_as(wx)
        # Calculate the probability p_h_given_v
        p_h_given_v = torch.sigmoid(activation)
        # Bernoulli sampling: each hidden unit is switched on (1) or off (0)
        # with probability p_h_given_v, matching the binary loves-it-or-not ratings
        return p_h_given_v, torch.bernoulli(p_h_given_v)
    def sample_v(self, y):
        """
        Sample the visible units
        :param y: mini-batch of hidden units
        """
        # The probability that v is activated, given h, is sigmoid(W'y + b)
        # No transpose here: W is stored as (n_hid, n_vis), so y @ W has shape (batch, n_vis)
        wy = torch.mm(y, self.W)
        # Broadcast the visible bias over the mini-batch
        activation = wy + self.v_bias.expand_as(wy)
        # Calculate the probability p_v_given_h
        p_v_given_h = torch.sigmoid(activation)
        # Bernoulli sampling: each visible unit is switched on (1) or off (0)
        # with probability p_v_given_h, i.e. whether a user loves the movie or not
        return p_v_given_h, torch.bernoulli(p_v_given_h)
    def train(self, v0, vk, ph0, phk):
        """
        Perform one step of contrastive divergence (CD) to update the parameters
        in the direction that lowers the energy of observed configurations,
        which approximately maximizes the log-likelihood of the model
        :param v0: visible units at the start of the Gibbs chain (the data)
        :param vk: visible units after k steps of Gibbs sampling
        :param ph0: p_h_given_v computed from v0
        :param phk: p_h_given_v computed from vk
        """
        # Approximate the gradients with the CD algorithm
        self.W += (torch.mm(v0.t(), ph0) - torch.mm(vk.t(), phk)).t()
        # Sum the differences over the batch dimension (0) and accumulate into the
        # biases (+=, not =, so that earlier updates are not discarded)
        self.v_bias += torch.sum((v0 - vk), 0)
        self.h_bias += torch.sum((ph0 - phk), 0)
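
As a usage sketch (not part of the original gist), the loop below trains the model with CD-k on synthetic binary ratings. Every name here (n_users, batch_size, n_epochs, k, training_set) is illustrative, and the convention that ratings are 0/1 with -1 marking unrated entries, which the Gibbs chain freezes back to the input, is an assumption borrowed from common MovieLens-style preprocessing rather than something this class enforces.

# Hypothetical training loop: CD-k on synthetic 0/1 ratings, -1 = unrated (assumed convention)
torch.manual_seed(0)
n_users, n_vis, n_hid = 500, 200, 50    # illustrative sizes
batch_size, n_epochs, k = 100, 10, 10   # illustrative hyperparameters

# Fake data: random 0/1 ratings, with roughly half the entries masked out as -1
training_set = torch.bernoulli(torch.rand(n_users, n_vis))
training_set[torch.rand(n_users, n_vis) < 0.5] = -1

rbm = RBM(n_vis, n_hid)
for epoch in range(1, n_epochs + 1):
    train_loss, n_batches = 0.0, 0
    for start in range(0, n_users - batch_size + 1, batch_size):
        v0 = training_set[start:start + batch_size]  # fixed data vector
        vk = v0.clone()                              # chain state, updated by Gibbs steps
        ph0, _ = rbm.sample_h(v0)
        for _ in range(k):
            _, hk = rbm.sample_h(vk)
            _, vk = rbm.sample_v(hk)
            vk[v0 < 0] = v0[v0 < 0]  # keep unrated entries frozen at -1
        phk, _ = rbm.sample_h(vk)
        rbm.train(v0, vk, ph0, phk)
        # Mean absolute reconstruction error on the rated entries only
        train_loss += torch.mean(torch.abs(v0[v0 >= 0] - vk[v0 >= 0])).item()
        n_batches += 1
    print(f"epoch {epoch}: loss {train_loss / n_batches:.4f}")

Note that train() applies the raw CD gradient with no learning rate; in practice you would usually scale the updates, but the sketch above matches the class exactly as written.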