PyTorch implementation of a soft decision tree. All gating computations are performed in a single batched step in order to make use of the GPU; a recursive definition might be faster on non-GPU machines.
import torch
class SoftTree(torch.nn.Module):
    def __init__(self, in_features, out_features, depth, projection='constant', dropout=0.0):
        super(SoftTree, self).__init__()
        self.proj = projection
        self.depth = depth
        self.in_features = in_features
        self.out_features = out_features
        # a complete binary tree of the given depth: 2**depth leaves, 2**depth - 1 gates.
        self.leaf_count = int(2**depth)
        self.gate_count = int(self.leaf_count - 1)
        # gating weights, stored transposed as (in_features, gate_count) so that
        # x @ gw yields all gate activations in one matmul.
        self.gw = torch.nn.Parameter(
            torch.nn.init.kaiming_normal_(
                torch.empty(self.gate_count, in_features), nonlinearity='sigmoid').t())
        self.gb = torch.nn.Parameter(torch.zeros(self.gate_count))
        # dropout applied to the gating weights.
        self.drop = torch.nn.Dropout(p=dropout)
        if self.proj == 'linear':
            # one linear projection per leaf, stored as (out_features, in_features, leaf_count).
            self.pw = torch.nn.init.kaiming_normal_(
                torch.empty(out_features * self.leaf_count, in_features), nonlinearity='linear')
            self.pw = torch.nn.Parameter(
                self.pw.view(out_features, self.leaf_count, in_features).permute(0, 2, 1))
            self.pb = torch.nn.Parameter(torch.zeros(out_features, self.leaf_count))
        elif self.proj == 'constant':
            # one constant output vector per leaf. TODO: find a better init for this.
            self.z = torch.nn.Parameter(torch.randn(out_features, self.leaf_count))
    def forward(self, x):
        node_densities = self.node_densities(x)
        # probability of each sample reaching each leaf: (leaf_count, batch).
        leaf_probs = node_densities[:, -self.leaf_count:].t()
        if self.proj == 'linear':
            # mix the per-leaf projections and biases by the leaf probabilities.
            gated_projection = torch.matmul(self.pw, leaf_probs).permute(2, 0, 1)
            gated_bias = torch.matmul(self.pb, leaf_probs).permute(1, 0)
            result = torch.matmul(gated_projection, x.view(-1, self.in_features, 1))[:, :, 0] + gated_bias
        elif self.proj == 'constant':
            result = torch.matmul(self.z, leaf_probs).permute(1, 0)
        return result
    def extra_repr(self):
        return "in_features=%d, out_features=%d, depth=%d, projection=%s" % (
            self.in_features,
            self.out_features,
            self.depth,
            self.proj)
    def node_densities(self, x):
        gw_ = self.drop(self.gw)
        # all gate activations in one batched matmul: (batch, gate_count).
        gatings = torch.sigmoid(torch.add(torch.matmul(x, gw_), self.gb))
        # nodes in level order; node 0 is the root with density 1.
        node_densities = torch.ones(x.shape[0], 2**(self.depth+1)-1, device=x.device)
        it = 1
        for d in range(1, self.depth+1):
            for i in range(2**d):
                # in level order, node it has parent (it-1)//2; odd indices are
                # left children (gating g), even indices right children (1 - g).
                parent_index = (it+1) // 2 - 1
                child_way = (it+1) % 2
                if child_way == 0:
                    parent_gating = gatings[:, parent_index]
                else:
                    parent_gating = 1 - gatings[:, parent_index]
                # clone to avoid an in-place autograd error when reading and
                # writing the same tensor.
                parent_density = node_densities[:, parent_index].clone()
                node_densities[:, it] = parent_density * parent_gating
                it += 1
        return node_densities
    def gatings(self, x):
        # gating probabilities without dropout, for inspection.
        with torch.no_grad():
            gatings = torch.sigmoid(torch.add(torch.matmul(x, self.gw), self.gb))
        return gatings
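A minimal usage sketch, not part of the original gist; the feature sizes, depth, and batch size below are arbitrary. Since every gate splits its parent's density into g and 1 - g, the leaf-reaching probabilities should sum to 1 for each sample.

tree = SoftTree(in_features=64, out_features=10, depth=4, projection='linear')
x = torch.randn(32, 64)
y = tree(x)
print(y.shape)  # torch.Size([32, 10])
leaf_probs = tree.node_densities(x)[:, -tree.leaf_count:]
print(leaf_probs.sum(dim=1))  # ~1.0 for every sample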