@rorybyrne
Last active August 30, 2022 09:37
Algorithm to recursively generate a hierarchical modular connectivity mask using PyTorch
"""A function to generate a modular connectivity mask.

You can use it to mask a weight matrix, and then mask the gradients using a hook to fully disable those connections:

    self.hh = Parameter(torch.normal(...))
    mask = hierarchical_modular_mask(200, level=3, density=0.05, scale=2.0)
    self.hh.data *= mask
    self.hh.register_hook(lambda grad: grad * mask)
"""
import math
from typing import Optional

import torch
from torch import Tensor


def hierarchical_modular_mask(
    size: int,
    *,
    level: int = 3,
    density: float = 0.05,
    scale: float = 2.0,
    dale_mask: Optional[Tensor] = None,
    inhibitory_distance: int = 0,
) -> Tensor:
    """Generate a connectivity matrix mask of hierarchical modules.

    Modules are densely connected locally, with increasing sparsity to nodes in more distal modules. If a
    `dale_mask` is given, the inhibitory units (0 values in the mask) will be disabled for all but the
    immediately local module. This gives a "local inhibition" effect, with long-range excitatory
    connections. If you want more distal inhibition, you can configure it with the `inhibitory_distance`
    argument.

    Args:
        size: The side-length of the square mask.
        level: The number of hierarchical levels to generate.
        density: The density of the highest level of the hierarchy (i.e. the off-diagonal regions).
        scale: The scaling factor for the density at each subsequent level of the hierarchy.
        dale_mask: A tensor of shape (size,), with 1 for excitatory (positive) units and 0 for inhibitory (negative) units.
        inhibitory_distance: The number of hierarchical levels that inhibitory connections may span.

    Returns:
        A (size, size) mask matrix of 1's and 0's.

    ref: https://www.nature.com/articles/srep22057
    """
    assert inhibitory_distance <= level, "inhibitory_distance cannot be larger than the total number of hierarchical levels"
    first_half = math.ceil(size / 2)
    second_half = math.floor(size / 2)

    # Sparse random background; only its off-diagonal (between-module) blocks are kept below.
    bg = (torch.rand((size, size)) <= density).float()
    # 1 everywhere except the two diagonal blocks.
    blank_diag = 1 - torch.block_diag(
        torch.ones(first_half, first_half),
        torch.ones(second_half, second_half),
    )

    if level > 0:
        # Recurse into each half with density * scale. Arguments must be passed by keyword
        # because of the `*` in the signature. Clamping inhibitory_distance to level - 1 keeps
        # the assertion valid in the child call without changing which levels are masked.
        child_distance = min(inhibitory_distance, level - 1)
        inner_diag = torch.block_diag(
            hierarchical_modular_mask(
                first_half,
                level=level - 1,
                density=density * scale,
                scale=scale,
                dale_mask=dale_mask[:first_half] if dale_mask is not None else None,
                inhibitory_distance=child_distance,
            ),
            hierarchical_modular_mask(
                second_half,
                level=level - 1,
                density=density * scale,
                scale=scale,
                dale_mask=dale_mask[first_half:] if dale_mask is not None else None,
                inhibitory_distance=child_distance,
            ),
        )
    else:
        # Base case: fill the two innermost diagonal blocks at density * scale.
        inner_diag = torch.block_diag(
            (torch.rand(first_half, first_half) <= density * scale).float(),
            (torch.rand(second_half, second_half) <= density * scale).float(),
        )

    if dale_mask is not None and level > inhibitory_distance:
        # Disable non-local inhibitory connections: broadcasting zeroes the column of every inhibitory unit
        bg *= dale_mask

    mask = (bg * blank_diag) + inner_diag
    return mask
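The module docstring shows the masking pattern only as fragments from inside a class. A minimal self-contained sketch of the same idea follows, assuming a hypothetical `MaskedRecurrent` module and a simple two-module stand-in mask in place of `hierarchical_modular_mask` (so the snippet runs on its own). `torch.no_grad()` with an in-place `mul_` plays the role of `self.hh.data *= mask`.

```python
import torch
from torch import nn


class MaskedRecurrent(nn.Module):
    """Hypothetical module: a recurrent weight `hh` restricted by a fixed 0/1 mask."""

    def __init__(self, mask: torch.Tensor):
        super().__init__()
        size = mask.shape[0]
        # Store the mask as a buffer so it follows the module across .to(device) calls.
        self.register_buffer("mask", mask)
        self.hh = nn.Parameter(torch.normal(0.0, 0.1, size=(size, size)))
        # Zero the disabled connections once at initialisation...
        with torch.no_grad():
            self.hh.mul_(self.mask)
        # ...and zero their gradients on every backward pass.
        self.hh.register_hook(lambda grad: grad * self.mask)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return torch.tanh(h @ self.hh)


# Stand-in mask: two fully connected modules with no connections between them.
mask = torch.block_diag(torch.ones(3, 3), torch.ones(3, 3))
model = MaskedRecurrent(mask)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 6)).pow(2).sum()
loss.backward()
opt.step()

# Disabled connections are still exactly zero after the update.
print(bool((model.hh[mask == 0] == 0).all()))
```

Masking only the initial weights is not enough on its own: the gradient for a zeroed entry is generally nonzero, so the hook is what keeps those connections disabled during training.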

rorybyrne commented Aug 27, 2022

Result:
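The densities in such a result follow from the recursion: each step down multiplies `density` by `scale`, and the base case fills the innermost diagonal blocks at one further factor of `scale`. A small sketch under that reading of the code, using a hypothetical helper `level_densities` that returns expected (not realised) densities, outermost first:

```python
def level_densities(level: int, density: float, scale: float) -> list[float]:
    """Expected density of each block type in hierarchical_modular_mask.

    Entry k (k = 0..level) is the off-diagonal block created k recursion steps
    below the top call; the final entry is the innermost diagonal blocks.
    """
    off_diagonal = [density * scale**k for k in range(level + 1)]
    innermost = density * scale ** (level + 1)
    return off_diagonal + [innermost]


# Defaults from the docstring example: level=3, density=0.05, scale=2.0
print(level_densities(3, 0.05, 2.0))  # [0.05, 0.1, 0.2, 0.4, 0.8]
```

Because the mask is drawn with `torch.rand(...) <= p` and `torch.rand` samples from [0, 1), any computed value of 1 or more simply means that block is fully connected.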


rorybyrne commented Aug 28, 2022

Now with Dale's Law/local inhibition. Columns in blue are inhibitory, and do not have any synapses outside their inner local module:
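The column masking comes from `bg *= dale_mask`: multiplying a (size, size) tensor by a (size,) tensor broadcasts along the last dimension, so column j is scaled by `dale_mask[j]`. The sketch below reproduces one level of that composition with dense stand-ins for the random `bg` and `inner_diag`, assuming a hypothetical 4-unit network where unit 1 is inhibitory:

```python
import torch

size, half = 4, 2
dale_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])  # hypothetical: unit 1 is inhibitory

bg = torch.ones(size, size)  # dense stand-in for the random background
blank_diag = 1 - torch.block_diag(torch.ones(half, half), torch.ones(half, half))
inner_diag = torch.block_diag(torch.ones(half, half), torch.ones(half, half))  # stand-in local modules

bg *= dale_mask  # zero every column that belongs to an inhibitory unit
mask = (bg * blank_diag) + inner_diag

# Inhibitory column: nonzero only inside its own module (rows 0 and 1).
print(mask[:, 1])  # tensor([1., 1., 0., 0.])
```

In the function itself this zeroing is applied at every level where `level > inhibitory_distance`, so with the default of 0 the innermost split (level 0) keeps its inhibitory connections.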
