rorybyrne/dd1ce80f8b17918ade3aed38f3e04b0b
Last active August 30, 2022 09:37
Algorithm to recursively generate a hierarchical modular connectivity mask using PyTorch
```python
"""A function to generate a modular connectivity mask.

You can use it to mask a weight matrix, and then mask the gradients using a hook
to fully disable those connections:

    self.hh = Parameter(torch.normal(...))
    mask = hierarchical_modular_mask(200, level=3, density=0.05, scale=2.0)
    self.hh.data *= mask
    self.hh.register_hook(lambda grad: grad * mask)
"""
import math
from typing import Optional

import torch
from torch import Tensor


def hierarchical_modular_mask(
    size: int,
    *,
    level: int = 3,
    density: float = 0.05,
    scale: float = 2.0,
    dale_mask: Optional[Tensor] = None,
    inhibitory_distance: int = 0,
) -> Tensor:
    """Generate a connectivity matrix mask of hierarchical modules.

    Modules are densely connected locally, with increasing sparsity to nodes in more
    distal modules. If a `dale_mask` is given, the inhibitory units (0 values in the
    mask) will be disabled for all but the immediately local module. This gives a
    "local inhibition" effect, with long-range excitatory connections. If you want more
    distal inhibition, you can configure it with the `inhibitory_distance` argument.

    Args:
        size: The side length of the square mask.
        level: The number of hierarchical levels to generate. The mask is bisected
            `level + 1` times, so the smallest modules have a side length of about
            `size / 2**(level + 1)`.
        density: The density of the highest level of the hierarchy (i.e. the
            off-diagonal regions).
        scale: The factor by which the density is multiplied at each deeper level
            of the hierarchy.
        dale_mask: A tensor of shape (size,), with 1 for excitatory (positive) units
            and 0 for inhibitory (negative) units.
        inhibitory_distance: The number of hierarchical steps that inhibitory
            connections may span.

    Returns:
        A (size, size) float mask matrix of 1s and 0s.

    ref: https://www.nature.com/articles/srep22057
    """
    assert inhibitory_distance <= level, (
        "inhibitory_distance cannot be larger than the total number of hierarchical levels"
    )
    first_half = math.ceil(size / 2)
    second_half = math.floor(size / 2)

    # Sparse background connectivity, used for the off-diagonal blocks at this level.
    bg = (torch.rand((size, size)) <= density).float()
    # 1 outside the two diagonal blocks, 0 inside them.
    blank_diag = 1 - torch.block_diag(
        torch.ones(first_half, first_half),
        torch.ones(second_half, second_half),
    )

    if level > 0:
        # Recurse into each half at a higher density. These parameters are keyword-only,
        # so they must be passed by name. inhibitory_distance is forwarded, clamped to
        # level - 1 so that the assertion above still holds one level down.
        child_kwargs = dict(
            level=level - 1,
            density=density * scale,
            scale=scale,
            inhibitory_distance=min(inhibitory_distance, level - 1),
        )
        inner_diag = torch.block_diag(
            hierarchical_modular_mask(
                first_half,
                dale_mask=dale_mask[:first_half] if dale_mask is not None else None,
                **child_kwargs,
            ),
            hierarchical_modular_mask(
                second_half,
                dale_mask=dale_mask[first_half:] if dale_mask is not None else None,
                **child_kwargs,
            ),
        )
    else:
        # Base case: fill the two innermost modules at the highest density.
        inner_diag = torch.block_diag(
            (torch.rand(first_half, first_half) <= density * scale).float(),
            (torch.rand(second_half, second_half) <= density * scale).float(),
        )

    if dale_mask is not None and level > inhibitory_distance:
        # Disable non-local inhibitory connections: the (size,) dale_mask broadcasts
        # across rows, zeroing the columns of inhibitory units.
        bg *= dale_mask

    mask = (bg * blank_diag) + inner_diag
    return mask
```
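The module docstring above shows the "mask the weights, then mask the gradients with a hook" pattern in pseudocode. Below is a minimal runnable sketch of that pattern, assuming a plain `torch.nn.Module`. The class name `MaskedRecurrentWeights` is hypothetical, and a small block-diagonal stand-in mask replaces `hierarchical_modular_mask` so that the snippet runs on its own.

```python
import torch
from torch import nn


class MaskedRecurrentWeights(nn.Module):
    """Hypothetical module: a recurrent weight matrix whose masked entries stay at zero."""

    def __init__(self, mask: torch.Tensor):
        super().__init__()
        # A buffer follows the module across .to(device) calls but is not trained.
        self.register_buffer("mask", mask.float())
        size = mask.shape[0]
        self.hh = nn.Parameter(torch.normal(0.0, 1.0 / size**0.5, size=(size, size)))
        with torch.no_grad():
            self.hh.mul_(self.mask)  # zero the disabled connections at initialisation
        # Zero the gradient of the disabled connections on every backward pass.
        self.hh.register_hook(lambda grad: grad * self.mask)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return torch.tanh(h @ self.hh)


torch.manual_seed(0)
# Stand-in mask: two dense 3x3 modules with no connections between them.
# In practice this could be the output of hierarchical_modular_mask(...).
mask = torch.block_diag(torch.ones(3, 3), torch.ones(3, 3))
layer = MaskedRecurrentWeights(mask)
opt = torch.optim.SGD(layer.parameters(), lr=0.1, momentum=0.9)

for _ in range(5):
    loss = layer(torch.randn(8, 6)).pow(2).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()

print(layer.hh[mask == 0].abs().max().item())  # prints 0.0
```

Because the masked entries start at zero and always receive a zero gradient, SGD (with or without momentum) never moves them. Registering the mask as a buffer, rather than capturing a free tensor in the lambda, keeps it on the same device as `hh`.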
Now with Dale's Law/local inhibition. Columns in blue are inhibitory and have no synapses outside their inner local module.
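The comment above describes inhibitory units as columns of the mask. That follows from the line `bg *= dale_mask`: a tensor of shape `(size,)` broadcasts against the last dimension of a `(size, size)` tensor, so column `j` is multiplied by `dale_mask[j]`. The toy 4-unit example below is illustrative only and isolates that broadcasting step.

```python
import torch

# Hypothetical 4-unit example: unit 1 is inhibitory (0), the rest are excitatory (1).
dale_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])
bg = torch.ones(4, 4)  # pretend every background connection was sampled

# Broadcasting multiplies column j by dale_mask[j], zeroing the inhibitory column.
masked = bg * dale_mask
print(masked)  # column 1 is all zeros, every other entry is 1
```

Zeroing column `j` removes unit `j`'s outgoing connections when the weights are applied as `W @ h`. With `h @ W`, it removes unit `j`'s incoming connections instead.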