Gradient estimators
import torch
from torch import nn
import torch.distributions as dist

## REINFORCE

adjacencies, num_edges, targets = load_data(...)
opt = ...

# parametrize the edge weights by a normal distribution
ew_mean = nn.Parameter(torch.randn(num_edges))
ew_std = nn.Parameter(torch.ones(num_edges))
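
# -- (aside) if ew_std is optimized directly like this, nothing stops it from going negative, which
#    dist.Normal will reject; a common alternative is to parametrize the log of the standard deviation,
#    e.g. (hypothetical variable name):
#
#    ew_logstd = nn.Parameter(torch.zeros(num_edges))
#    ew_dist = dist.Normal(ew_mean, ew_logstd.exp())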

for _ in range(num_epochs):
    opt.zero_grad()

    # reparametrized sample from the normal distribution
    ew_dist = dist.Normal(ew_mean, ew_std)
    raw_weights = ew_dist.rsample()
    # -- we take a reparametrized sample from the normal distribution (see the VAE lecture)

    edge_weights = torch.softmax(raw_weights, dim=0)
    # -- or spherical normalization or whatever

    # compute two loss terms
    gcn_loss = loss(gcn(adjacencies, edge_weights), targets)  # regular GCN computation (simple backpropagation is fine here)

    with torch.no_grad():
        kemeny_loss = - alpha * kemeny(adjacencies, edge_weights)
    # -- the computation of the Kemeny constant is not easily differentiable, so we do it under torch.no_grad()
    #    and use REINFORCE to estimate the gradient.

    # actual loss
    # total_loss = gcn_loss + kemeny_loss
    # -- we won't get a gradient over this, because kemeny_loss is detached from the computation graph

    # estimated (surrogate) loss
    total_loss = gcn_loss + ew_dist.log_prob(raw_weights).sum() * kemeny_loss
    # -- To see what happens here, write down the expected gradient of the actual loss under the normal distribution above.
    #    The expectation over the first term can be estimated simply by letting pytorch work out the gradient: the
    #    reparametrization results in a gradient on ew_mean and ew_std.
    # -- The expectation over the second term we rewrite using the score function (so this becomes a score-function
    #    estimate with a single sample). The value `kemeny_loss` is just a constant, but the log probability of the
    #    sampled weights will get a gradient from this REINFORCE term.
    # -- Note that the _derivative_ of this second term is kemeny_loss times the score function (treating kemeny_loss
    #    as a constant), which is exactly the REINFORCE estimate we're looking for. By adding this term to the loss,
    #    we're tricking pytorch into computing the gradient estimate and backpropagating it.
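
    # -- (aside) the identity being used here, for a black-box f(w) and a density p(w) with parameters theta:
    #
    #        grad_theta E_{w ~ p}[ f(w) ]  =  E_{w ~ p}[ f(w) * grad_theta log p(w) ]
    #
    #    so a single sample w gives the estimate f(w) * grad_theta log p(w), which is the gradient of the
    #    surrogate term above (with f(w) = kemeny_loss treated as a constant).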

    total_loss.backward()
    opt.step()
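
# -- A minimal, self-contained sanity check of the surrogate-loss trick above (a toy problem, not part of the
#    GCN pipeline; all names here are hypothetical): minimize E_{x ~ N(mu, 1)}[ (x - 3)^2 ] while treating the
#    inner function as a black box computed under torch.no_grad(), just like the Kemeny constant above.
#    After a couple of thousand steps mu should have drifted towards 3 (the estimator is unbiased but noisy).
mu = nn.Parameter(torch.zeros(1))
toy_opt = torch.optim.Adam([mu], lr=0.05)
for _ in range(2000):
    toy_opt.zero_grad()
    d = dist.Normal(mu, 1.0)
    x = d.sample()                        # plain sample: no reparametrization path back to mu
    with torch.no_grad():
        f = ((x - 3.0) ** 2).sum()        # "black box" objective, detached from the graph
    surrogate = d.log_prob(x).sum() * f   # the gradient of this term is the REINFORCE estimate
    surrogate.backward()
    toy_opt.step()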

## SPSA

# -- We proceed in the same way: compute the gradient estimate in a detached way (under torch.no_grad()) and add its
#    integrand to the loss, so that pytorch sets the gradient estimate as the gradient of the relevant nodes and
#    backpropagates from there.

edge_weights = nn.Parameter(torch.randn(num_edges))
opt = ...  # a fresh optimizer over edge_weights

STD = 1e-7  # size of the perturbation

for _ in range(num_epochs):
    opt.zero_grad()

    # -- sparse softmax or spherical normalization or whatever; the result goes in a separate variable so
    #    that the parameter edge_weights itself is not overwritten
    weights = torch.softmax(edge_weights, dim=0)

    # compute the GCN loss once
    gcn_loss = loss(gcn(adjacencies, weights), targets)  # regular GCN computation (simple backpropagation is fine here)

    # and the Kemeny loss twice
    with torch.no_grad():
        perturbation = torch.randn(num_edges) * STD
        # -- this should be a Bernoulli (+/- STD) perturbation for a proper SPSA implementation
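        # -- for instance (a sketch of the proper version): sample each component as +STD or -STD
        # perturbation = (torch.randint(0, 2, (num_edges,)) * 2 - 1).float() * STD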

        edge_weights0, edge_weights1 = weights + perturbation, weights - perturbation
        normalized0, normalized1 = torch.softmax(edge_weights0, dim=0), torch.softmax(edge_weights1, dim=0)

        kemeny_loss0 = - alpha * kemeny(adjacencies, normalized0)
        kemeny_loss1 = - alpha * kemeny(adjacencies, normalized1)
    # -- the computation of the Kemeny constant is not easily differentiable, so we do it under torch.no_grad()
    #    and use SPSA to estimate the gradient.

    # actual loss
    # total_loss = gcn_loss - alpha * kemeny(adjacencies, weights)
    # -- we won't get a gradient over this, because the kemeny term is detached from the computation graph

    # estimated (surrogate) loss
    total_loss = gcn_loss + (weights * ((kemeny_loss0 - kemeny_loss1) / (2.0 * perturbation))).sum()
    # -- Note that when we take the derivative of the second term with respect to `weights`, we get exactly the
    #    SPSA estimate of the gradient, which then backpropagates through the softmax into edge_weights.
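
    # -- (aside) written out per component i, the estimate used above is the standard simultaneous-perturbation
    #    form:
    #
    #        g_i  ~=  ( f(w + delta) - f(w - delta) ) / (2 * delta_i)
    #
    #    where f is the (negated, scaled) Kemeny term and delta is the random perturbation vector.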

    total_loss.backward()
    opt.step()
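
# -- A minimal, self-contained sanity check of the SPSA estimate (a toy problem, not part of the GCN pipeline;
#    all names here are hypothetical): estimate the gradient of a black-box quadratic f(x) = sum((x - 1)^2) at a
#    random point and compare it to the true gradient 2 * (x - 1). A single estimate is noisy; it only matches
#    the true gradient in expectation, so averaging over many perturbations brings it closer.
x = torch.randn(5)
eps = 1e-3
delta = (torch.randint(0, 2, (5,)) * 2 - 1).float() * eps    # Rademacher (+/- eps) perturbation
with torch.no_grad():
    f_plus = ((x + delta - 1.0) ** 2).sum()
    f_minus = ((x - delta - 1.0) ** 2).sum()
    spsa_grad = (f_plus - f_minus) / (2.0 * delta)
true_grad = 2.0 * (x - 1.0)
print(spsa_grad)
print(true_grad)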