Created
December 19, 2024 20:03
-
-
Save dmelcer9/286d5fd674111380107b321998313f34 to your computer and use it in GitHub Desktop.
A nice multi-level indexing function on tensors
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def nice_index(tens: Tensor, index: Tensor) -> Tensor:
    """
    Gather entries of ``tens`` using multi-level indices stored in ``index``.

    The last dimension of ``index`` is a multi-level specifier: its length is
    the number of leading dimensions of ``tens`` that get indexed.
    All other dimensions of ``index`` are batch dimensions.

    For example, if the index is [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    the result will be [[tens[1, 2], tens[3, 4]], [tens[5, 6], tens[7, 8]]]

    A specifier of length zero selects the whole of ``tens`` for every batch
    entry, so an index of shape ``(2, 0)`` returns two stacked copies of
    ``tens``.

    Args:
        tens: Tensor to read from.
        index: Tensor of shape ``batch_shape + (num_levels,)``. Values are
            converted with ``.long()`` before being used as indices.

    Returns:
        A new tensor of shape ``batch_shape + tens.shape[num_levels:]``.

    Raises:
        IndexError: If ``index`` has no dimensions, or if it specifies more
            levels than ``tens`` has dimensions.
    """
    if index.dim() == 0:
        raise IndexError("Index must have at least one dimension")

    num_levels_of_index = index.shape[-1]
    if num_levels_of_index > tens.dim():
        raise IndexError(
            f"Index too long: {num_levels_of_index} levels on a {tens.dim()}-dimensional tensor"
        )

    output_shape = index.shape[:-1] + tens.shape[num_levels_of_index:]

    if num_levels_of_index == 0:
        # Nothing to look up: broadcast ``tens`` over the batch dimensions.
        # ``clone`` materialises the broadcast view so the caller gets an
        # independent tensor, as on the indexing path below.
        return tens.expand(output_shape).clone()

    # One index tensor per indexed level, each of shape ``batch_shape``.
    # Advanced indexing with these yields ``batch_shape + remaining dims``
    # directly, including when the batch is empty.
    per_level = index.long().to(tens.device).unbind(dim=-1)
    return tens[per_level].reshape(output_shape)
class TestNiceIndex(unittest.TestCase):
    """Checks nice_index against hand-written expected results.

    Test names read ``test_<rank of tens>_<rank of index>``.
    """

    def expect(self, tens, index, output):
        """Convert both inputs to tensors, run nice_index, compare ``tolist()`` to ``output``."""
        source = torch.as_tensor(tens)
        selector = torch.as_tensor(index)
        actual = nice_index(source, selector).tolist()
        self.assertEqual(actual, output)

    def test_1d_1d(self):
        # A single one-level specifier picks out one scalar.
        self.expect([1, 2, 3, 4], [3], 4)

    def test_1d_2d(self):
        # One batch dimension of one-level specifiers.
        self.expect([1, 2, 3, 4], [[2], [1]], [3, 2])

    def test_1d_3d(self):
        # Two batch dimensions of one-level specifiers.
        values = [0, 1, 2, 3]
        picks = [[[1], [2]], [[3], [0]]]
        self.expect(values, picks, [[1, 2], [3, 0]])

    def test_1d_0d(self):
        # Zero-level specifiers return the whole tensor per batch entry.
        values = [0, 1, 2, 3]
        self.expect(values, [], [0, 1, 2, 3])
        self.expect(values, [[], []], [[0, 1, 2, 3], [0, 1, 2, 3]])

    def test_2d_1d(self):
        # One level selects a row; two levels select a single element.
        grid = [[0, 1], [2, 3]]
        self.expect(grid, [1], [2, 3])
        self.expect(grid, [1, 1], 3)

    def test_2d_2d(self):
        # Batched row selection, then batched element selection.
        grid = [[0, 1], [2, 3]]
        self.expect(grid, [[1], [0]], [[2, 3], [0, 1]])
        self.expect(grid, [[1, 0], [0, 1]], [2, 1])

    def test_2d_3d(self):
        # Two batch dimensions of two-level specifiers.
        grid = [[0, 1], [2, 3]]
        picks = [[[0, 1], [1, 0]], [[1, 1], [0, 0]]]
        self.expect(grid, picks, [[1, 2], [3, 0]])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The core indexing comes from https://stackoverflow.com/a/52092603; the rest of the function handles more complex batch shapes.