Created January 24, 2020 00:35
-
-
Save yaroslavvb/3335705e63121f3f6e892a4a21bf8a6b to your computer and use it in GitHub Desktop.
index_reduce
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def index_reduce(values, indices, dim):
    """Reduce ``values`` along ``dim`` by picking one element per position.

    ``indices`` has the shape of ``values`` with dimension ``dim`` removed, and
    each entry says which element along ``dim`` to keep. For a rank-3
    ``values`` tensor the result is the rank-2 tensor

        dim=0: out[i, j] = values[indices[i, j], i, j]
        dim=1: out[i, j] = values[i, indices[i, j], j]
        dim=2: out[i, j] = values[i, j, indices[i, j]]

    When every entry of ``indices`` equals ``p``, the result equals slicing
    along that dimension: ``values[p, :, :]``, ``values[:, p, :]`` or
    ``values[:, :, p]`` respectively.

    Args:
        values: tensor to reduce.
        indices: integer tensor whose shape equals ``values.shape`` with
            entry ``dim`` removed. Entries are expected to lie in
            ``[0, values.shape[dim])``. Their range is not checked here and is
            left to ``torch.gather``.
        dim: dimension of ``values`` to reduce. Negative values count from
            the end, as in Python indexing.

    Returns:
        Tensor with the shape of ``indices`` and the dtype of ``values``.

    Raises:
        ValueError: if ``indices.shape`` is not ``values.shape`` with ``dim``
            removed.
    """
    expected_shape = list(values.shape)
    del expected_shape[dim]
    # Compare full shapes, not just element counts: torch.gather needs the
    # index to line up with ``values`` in every dimension other than ``dim``.
    if list(indices.shape) != expected_shape:
        raise ValueError(
            f"indices shape {tuple(indices.shape)} must equal values shape "
            f"{tuple(values.shape)} with dim {dim} removed, "
            f"i.e. {tuple(expected_shape)}"
        )
    # torch.gather expects int64 indices; .long() is a no-op for int64 input.
    gathered = torch.gather(values, dim, indices.long().unsqueeze(dim))
    return gathered.squeeze(dim)
def test_index_reduce():
    """Check index_reduce against slicing and against explicit per-element indexing."""
    values = torch.arange(0, 8).reshape(2, 2, 2)
    # Uniform indices: the result must match a plain slice for every position
    # p. Testing p=0 alone cannot tell "select p" apart from "select first".
    for pos in range(2):
        indices = pos * torch.ones(2, 2).long()
        assert torch.equal(index_reduce(values, indices, 0), values[pos, :, :])
        assert torch.equal(index_reduce(values, indices, 1), values[:, pos, :])
        assert torch.equal(index_reduce(values, indices, 2), values[:, :, pos])
    # Non-uniform indices: compare against the element-wise definition
    # values[..., indices[i, j], ...] with the index inserted at position dim.
    indices = torch.tensor([[1, 0], [0, 1]])
    for dim in range(3):
        expected = torch.empty_like(indices)
        for i in range(2):
            for j in range(2):
                where = [i, j]
                where.insert(dim, int(indices[i, j]))
                expected[i, j] = values[tuple(where)]
        assert torch.equal(index_reduce(values, indices, dim), expected)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.