Last active: September 18, 2022 12:19
Block diagonal matrix in PyTorch - vectorized
```python
"""A part of the pylabyk library: numpytorch.py at https://github.com/yulkang/pylabyk"""
import torch


def block_diag(m):
    """
    Make a block diagonal matrix along dim=-3

    EXAMPLE:
    block_diag(torch.ones(4,3,2))
    should give a 12 x 8 matrix with blocks of 3 x 2 ones.

    Prepend batch dimensions if needed.
    You can also give a list of matrices.

    :type m: torch.Tensor, list
    :rtype: torch.Tensor
    """
    if isinstance(m, (list, tuple)):
        # Stack equally sized blocks along a new dim=-3: (..., n, rows, cols)
        m = torch.stack(m, -3)

    d = m.dim()
    n = m.shape[-3]      # number of blocks
    siz0 = m.shape[:-3]  # batch dimensions
    siz1 = m.shape[-2:]  # (rows, cols) of each block

    # (..., n, rows, 1, cols)
    m2 = m.unsqueeze(-2)
    # Identity over block indices, shaped (1, ..., 1, n, 1, n, 1), so that it
    # broadcasts against m2 and zeroes every off-diagonal block.
    # Match m's dtype and device to avoid type promotion or device errors.
    eye = attach_dim(
        torch.eye(n, dtype=m.dtype, device=m.device).unsqueeze(-2), d - 3, 1
    )
    # (..., n, rows, n, cols) -> (..., n * rows, n * cols)
    return (m2 * eye).reshape(siz0 + torch.Size([s * n for s in siz1]))


def attach_dim(v, n_dim_to_prepend=0, n_dim_to_append=0):
    """Add singleton dimensions before and/or after the existing ones."""
    return v.reshape(
        torch.Size([1] * n_dim_to_prepend)
        + v.shape
        + torch.Size([1] * n_dim_to_append))
```
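The core of `block_diag` is multiplying each block by an identity matrix over block indices, so that output block (i, j) equals block i when i == j and is zero otherwise. The same masking can, as far as I can tell, be expressed with `torch.einsum`. The sketch below is an illustration, not part of pylabyk: the name `block_diag_einsum` is hypothetical, and it assumes all blocks share one shape and are stored as `m[..., n, rows, cols]`.

```python
import torch


def block_diag_einsum(m: torch.Tensor) -> torch.Tensor:
    """Hypothetical einsum variant: (..., n, rows, cols) -> (..., n*rows, n*cols)."""
    n, rows, cols = m.shape[-3:]
    eye = torch.eye(n, dtype=m.dtype, device=m.device)
    # out[..., i, a, j, b] = m[..., i, a, b] * eye[i, j]: off-diagonal blocks become 0
    out = torch.einsum("...iab,ij->...iajb", m, eye)
    # Merge (i, a) into rows and (j, b) into columns
    return out.reshape(*m.shape[:-3], n * rows, n * cols)


if __name__ == "__main__":
    x = torch.arange(24.0).reshape(2, 3, 2, 2)  # batch of 2, each with 3 blocks of 2 x 2
    print(block_diag_einsum(x).shape)  # torch.Size([2, 6, 6])
```

Both versions materialize the full (..., n, rows, n, cols) product before reshaping, so memory grows with n squared. That is the usual cost of this kind of vectorization.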
For my own record, this addresses the following questions:
https://discuss.pytorch.org/t/creating-a-block-diagonal-matrix/17357
https://discuss.pytorch.org/t/creating-a-block-diagonal-matrix/22592
https://stackoverflow.com/questions/54856333/pytorch-diagonal-matrix-block-set-efficiently/56638727#56638727
Example:

```python
>>> block_diag(torch.ones(4,3,2))
tensor([[1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 1., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0., 1., 1.]])
```
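PyTorch 1.8 and later also provide `torch.block_diag(*tensors)`, which builds one block diagonal matrix from 2-D blocks (possibly of different sizes) but has no batch dimension. A per-batch Python loop over it is slow but simple, which makes it a convenient reference for checking the vectorized version. `block_diag_loop` is a hypothetical helper for illustration, not part of pylabyk.

```python
import torch


def block_diag_loop(m: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: (..., n, rows, cols) -> (..., n*rows, n*cols) via a loop."""
    batch_shape = m.shape[:-3]
    flat = m.reshape(-1, *m.shape[-3:])  # (batch, n, rows, cols)
    # torch.block_diag takes the n blocks of one batch element as separate arguments
    out = torch.stack([torch.block_diag(*blocks) for blocks in flat])
    # Restore the original batch dimensions
    return out.reshape(*batch_shape, *out.shape[-2:])


if __name__ == "__main__":
    print(block_diag_loop(torch.ones(4, 3, 2)).shape)     # torch.Size([12, 8])
    print(block_diag_loop(torch.ones(5, 4, 3, 2)).shape)  # torch.Size([5, 12, 8])
```

Comparing this loop with `block_diag` on random inputs (e.g. via `torch.equal`) is one quick way to gain confidence in the broadcasting version.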
For an up-to-date version, check numpytorch.py in my pylabyk library: https://github.com/yulkang/pylabyk